当前位置: 首页 > news >正文

什么是MMOE?

一、概念

        MMOE(Multi-gate Mixture-of-Experts)由MOE改进而来,是一种用于多任务学习的深度学习模型架构,特别适合处理具有多个目标任务的场景(搜广推领域常客)。它通过引入专家网络(Experts)和多门机制(Multi-gate)来实现任务间的共享和独立性,解决了多任务学习中任务冲突的问题。

        在多任务学习中,多个任务共享一个模型的部分参数(通常是底层特征),但由于任务之间可能存在冲突(例如分类任务和回归任务对特征的需求不同),直接共享参数可能导致性能下降。MMOE通过引入专家网络和门机制,动态地为每个任务分配合适的专家,从而实现任务间的协作和独立性。MMOE的架构主要由以下几个部分组成:

1、专家网络(Experts)

  • 专家网络是多个独立的子网络,每个子网络负责学习特定的特征。
  • 专家网络的输出是共享的,所有任务都可以使用这些输出。
  • 专家网络的数量可以根据具体问题设置,通常是多个。

2、门机制(Gates)

  • 每个任务都有一个独立的门机制(Gate),用于为该任务动态分配专家网络的权重。
  • 门机制是一个小型的神经网络,输入是共享的特征,输出是专家网络的权重分布。
  • 门机制的输出权重决定了每个任务如何组合专家网络的输出。

3、任务特定的塔(Task-specific Towers)

  • 每个任务都有一个独立的塔(Tower),用于处理任务特定的特征并生成最终的预测。
  • 塔的输入是门机制加权后的专家网络输出。

二、原理

        设有K个专家网络和T个任务,MMOE的数学表示如下:

1、专家网络输出

        其中,X是输入特征,是第k个专家网络。

2、门机制权重

        其中,是第t个任务的门机制,是权重向量。

3、加权组合专家输出

        其中,是任务t对专家k的权重。

4、任务塔生成预测

        其中,是任务t的塔。

三、python实现

        这里直接给出MMOE的构建过程,假设我们有三个任务。后续的训练过程与普通神经网络一致,唯一需要注意的是损失函数的构建:如果每个任务使用不同的损失函数,则分别计算损失之后合并为总损失即可;否则直接使用同一个损失函数计算总损失

import torch
import torch.nn as nnclass Expert(nn.Module):def __init__(self, input_dim, hidden_dim):super().__init__()self.net = nn.Sequential(nn.Linear(input_dim, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, hidden_dim))def forward(self, x):return self.net(x)class MMoE(nn.Module):def __init__(self, input_dim, num_experts=4, num_tasks=3):super().__init__()self.experts = nn.ModuleList([Expert(input_dim, 64) for _ in range(num_experts)])# 每个任务独立门控self.gates = nn.ModuleList([nn.Sequential(nn.Linear(input_dim, num_experts),nn.Softmax(dim=-1)) for _ in range(num_tasks)])# 任务专属塔层self.towers = nn.ModuleList([nn.Sequential(nn.Linear(64, 32),nn.ReLU(),nn.Linear(32, 1)) for _ in range(num_tasks)])def forward(self, x):expert_outputs = torch.stack([e(x) for e in self.experts], dim=1)  # [batch, experts, dim]outputs = []for gate, tower in zip(self.gates, self.towers):weights = gate(x).unsqueeze(-1)  # [batch, experts, 1]combined = (expert_outputs * weights).sum(1)  # [batch, dim]outputs.append(tower(combined).squeeze())return torch.stack(outputs, dim=1)  # [batch, tasks]# 使用示例
model = MMoE(input_dim=128)
x = torch.randn(32, 128)  # 批量32,特征128
y_pred = model(x)          # 输出形状[32,3]# 虚拟三任务标签
y_true = torch.rand(32, 3)  
loss_fn = nn.BCEWithLogitsLoss()
loss = loss_fn(y_pred, y_true)
print(loss)

 

http://www.xdnf.cn/news/13501.html

相关文章:

  • 坐标上海,20~40K的面试强度
  • Android Studio 常见报错
  • 计算机网络——应用层
  • 济南通过首个备案生活服务大模型,打造行业新标杆
  • 【人工智能】Prompt攻击与防范策略总结
  • 2025年03月中国电子学会青少年软件编程(Python)等级考试试卷(三级)答案 + 解析
  • ELF2开发板的ubuntu系统的ax200 wifi配网
  • Vue 3.0 Composition API 与 Vue 2.x Options API 的区别
  • 8.Rust+Axum 数据库集成实战:从 ORM 选型到用户管理系统开发
  • 2025MathorcupC题 音频文件的高质量读写与去噪优化 保姆级教程讲解|模型讲解
  • Docker中镜像、容器、仓库三者之间的关系
  • 第 8 期:条件生成 DDPM:让模型“听话”地画图!
  • Hadoop的三大结构及各自的作用?
  • TDengine Restful 接口API
  • excel解析图片pdf附件不怕
  • ESP8266简单介绍
  • 2025年山东燃气瓶装送气工考试真题练习
  • MCP协议量子加密实践:基于QKD的下一代安全通信(2025深度解析版)
  • 从数字化到智能化,百度 SRE 数智免疫系统的演进和实践
  • MCP(Model Context Protocol 模型上下文协议)科普
  • vue 中formatter
  • 2025-04-18 李沐深度学习3 —— 线性代数
  • yarn的三大组件及各自作用
  • easyexcel使用模板填充excel坑点总结
  • Kotlin协程Semaphore withPermit约束并发任务数量
  • chili3d调试笔记3 加入c++ 大模型对话方法 cmakelists精读
  • PY32F003+TIM+外部中断实现对1527解码
  • 【Test Test】灰度化和二值化处理图像
  • 6TOPS算力NPU加持!RK3588如何重塑8K显示的边缘计算新边界
  • 嵌入式音视频开发指南:从MPP框架到QT实战全解析