MMoE: Enhancing Multimodal Models with Mixtures of Multimodal Interaction Experts

基本信息

注:正文实际讨论的是 Ma et al. KDD’18 的经典 MMoE(Multi-gate Mixture-of-Experts),与 frontmatter 标题对应的 2023 多模态 MMoE 不是同一篇,这里按正文实际涉及的经典版填写。

字段 内容
标题 Modeling Task Relationships in Multi-task Learning with Multi-gate Mixture-of-Experts
作者 Jiaqi Ma, Zhe Zhao, Xinyang Yi, Jilin Chen, Lichan Hong, Ed H. Chi
机构 University of Michigan; Google
年份 2018 (KDD’18)
方向 Multi-Task Learning, Mixture-of-Experts, Multi-Gate
场景 多任务推荐排序中的负迁移缓解
会议 https://dl.acm.org/doi/10.1145/3219819.3220007

KDD 2018 Google

ESMM 专注于解决转化率预估中的样本偏差和稀疏问题

MMoEPLE 专注于解决多任务之间的相关性与冲突问题

主要为了解决多个任务之间相关性低导致的模型效果下降问题(负迁移问题)

主要解决办法:引入Multi-gated MoE结构,将传统Share Bottom然后接各自任务塔的方式改变为由多个专家组成的MoE层,每个任务配备一个单独的门控,经过门控加权后的专家混合输出被送到各自的任务塔,计算出最终的预测结果。

优势:1.虽然引入MoE但是门控很轻量级,专家之间还是可以共享,因此参数量增加不大,同时仍能进行一定的迁移学习。2.训练更稳定

激活函数:ReLU(MLP),Softmax(Gated MLP)

损失函数:BCE

评价指标:线下AUC R-Squared MSE, 线上CTR、观看时长、点赞率

注意:MMoE通常只用一层专家层,在大规模实验中,作者也只是将Share Bottom的顶层替换成了一层MMoE层。

缺点:

  1. 跷跷板现象 (Seesaw Phenomenon):尽管缓解了负迁移,但在极其复杂的任务组合下,仍可能出现“一个任务提升,另一个任务显著下降”的现象。因为所有专家在理论上对所有任务都是“可见”的,主导任务可能会“劫持”大部分专家的梯度更新。
  2. 门控极化:在训练初期,门控可能过早收敛到某几个专家(Winner-take-all),导致其他专家得不到充分训练(这一点通常需要配合 Dropout 或负载均衡 Loss 来缓解)。

关键部分代码实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import torch
import torch.nn as nn
import torch.nn.functional as F

class Expert(nn.Module):
def __init__(self, input_dim, hidden_dim):
super(Expert, self).__init__()
# 论文中专家通常是简单的单层或多层 MLP
self.net = nn.Sequential(
nn.Linear(input_dim, hidden_dim),
nn.ReLU(), # [cite: 282]
# 工业界常在此处加 Dropout 或 BatchNormalization
)

def forward(self, x):
return self.net(x)

class MMoE(nn.Module):
def __init__(self, input_dim, num_experts, expert_dim, num_tasks, tower_dims):
super(MMoE, self).__init__()
self.num_experts = num_experts
self.num_tasks = num_tasks

# 1. 初始化专家网络 (所有任务共享)
self.experts = nn.ModuleList([
Expert(input_dim, expert_dim) for _ in range(num_experts)
])

# 2. 初始化门控网络 (每个任务一个独立的门控)
# 门控就是一个线性层,输入是特征,输出是专家数量维度的logits
# 公式:g^k(x) = softmax(W_gk * x)
self.gates = nn.ModuleList([
nn.Linear(input_dim, num_experts) for _ in range(num_tasks)
])

# 3. 初始化任务塔 (每个任务一个独立的塔)
# 塔的输入维度等于专家的输出维度
self.towers = nn.ModuleList([
nn.Sequential(
nn.Linear(expert_dim, tower_dims[0]),
nn.ReLU(),
nn.Linear(tower_dims[0], 1) # 假设是二分类 CTR/CVR
) for _ in range(num_tasks)
])

def forward(self, x):
# --- A. 计算所有专家的输出 ---
# outputs shape: [batch_size, num_experts, expert_dim]
expert_outputs = [expert(x) for expert in self.experts]
expert_outputs = torch.stack(expert_outputs, dim=1)

final_outputs = []

# --- B. 为每个任务计算特定的加权输入 ---
for i in range(self.num_tasks):
# 1. 计算门控权重
# gate_logits shape: [batch_size, num_experts]
gate_logits = self.gates[i](x)

# 关键:必须使用 Softmax 归一化权重
gate_weights = F.softmax(gate_logits, dim=1)

# 2. 扩展维度以便进行广播乘法
# gate_weights shape 变为: [batch_size, num_experts, 1]
gate_weights = gate_weights.unsqueeze(2)

# 3. 加权求和 (Weighted Sum)
# 公式:f^k(x) = sum(g(x)_i * f_i(x)) [cite: 279]
# element-wise product: [batch, num_experts, 1] * [batch, num_experts, expert_dim]
# sum dim=1: 结果 shape 为 [batch_size, expert_dim]
weighted_expert_output = torch.sum(expert_outputs * gate_weights, dim=1)

# --- C. 输入到对应的任务塔 ---
task_output = self.towers[i](weighted_expert_output)
final_outputs.append(task_output)

return final_outputs