什么是MMOE?
一、概念
MMOE(Multi-gate Mixture-of-Experts)由MOE改进而来,是一种用于多任务学习的深度学习模型架构,特别适合处理具有多个目标任务的场景(搜广推领域常客)。它通过引入专家网络(Experts)和多门机制(Multi-gate)来实现任务间的共享和独立性,解决了多任务学习中任务冲突的问题。
在多任务学习中,多个任务共享一个模型的部分参数(通常是底层特征),但由于任务之间可能存在冲突(例如分类任务和回归任务对特征的需求不同),直接共享参数可能导致性能下降。MMOE通过引入专家网络和门机制,动态地为每个任务分配合适的专家,从而实现任务间的协作和独立性。MMOE的架构主要由以下几个部分组成:
1、专家网络(Experts)
- 专家网络是多个独立的子网络,每个子网络负责学习特定的特征。
- 专家网络的输出是共享的,所有任务都可以使用这些输出。
- 专家网络的数量可以根据具体问题设置,通常是多个。
2、门机制(Gates)
- 每个任务都有一个独立的门机制(Gate),用于为该任务动态分配专家网络的权重。
- 门机制是一个小型的神经网络,输入是共享的特征,输出是专家网络的权重分布。
- 门机制的输出权重决定了每个任务如何组合专家网络的输出。
3、任务特定的塔(Task-specific Towers)
- 每个任务都有一个独立的塔(Tower),用于处理任务特定的特征并生成最终的预测。
- 塔的输入是门机制加权后的专家网络输出。
二、原理
设有K个专家网络和T个任务,MMOE的数学表示如下:
1、专家网络输出
其中,X是输入特征,是第k个专家网络。
2、门机制权重
其中,是第t个任务的门机制,
是权重向量。
3、加权组合专家输出
其中,是任务t对专家k的权重。
4、任务塔生成预测
其中,是任务t的塔。
三、python实现
这里直接给出MMOE的构建过程,假设我们有三个任务。后续的训练过程与普通神经网络一致,唯一需要注意的是损失函数的构建:如果每个任务使用不同的损失函数,则分别计算损失之后合并为总损失即可;否则直接使用同一个损失函数计算总损失。
import torch
import torch.nn as nnclass Expert(nn.Module):def __init__(self, input_dim, hidden_dim):super().__init__()self.net = nn.Sequential(nn.Linear(input_dim, hidden_dim),nn.ReLU(),nn.Linear(hidden_dim, hidden_dim))def forward(self, x):return self.net(x)class MMoE(nn.Module):def __init__(self, input_dim, num_experts=4, num_tasks=3):super().__init__()self.experts = nn.ModuleList([Expert(input_dim, 64) for _ in range(num_experts)])# 每个任务独立门控self.gates = nn.ModuleList([nn.Sequential(nn.Linear(input_dim, num_experts),nn.Softmax(dim=-1)) for _ in range(num_tasks)])# 任务专属塔层self.towers = nn.ModuleList([nn.Sequential(nn.Linear(64, 32),nn.ReLU(),nn.Linear(32, 1)) for _ in range(num_tasks)])def forward(self, x):expert_outputs = torch.stack([e(x) for e in self.experts], dim=1) # [batch, experts, dim]outputs = []for gate, tower in zip(self.gates, self.towers):weights = gate(x).unsqueeze(-1) # [batch, experts, 1]combined = (expert_outputs * weights).sum(1) # [batch, dim]outputs.append(tower(combined).squeeze())return torch.stack(outputs, dim=1) # [batch, tasks]# 使用示例
model = MMoE(input_dim=128)
x = torch.randn(32, 128) # 批量32,特征128
y_pred = model(x) # 输出形状[32,3]# 虚拟三任务标签
y_true = torch.rand(32, 3)
loss_fn = nn.BCEWithLogitsLoss()
loss = loss_fn(y_pred, y_true)
print(loss)