DeepSeek-V4: Towards Highly Efficient Million-Token Context Intelligence
TL;DR
研究动机
推理模型确立了 test-time scaling 这一范式,但标准注意力的计算量随序列长度二次增长,成为超长上下文与长推理过程的瓶颈;同时,智能体、跨文档分析等长程任务也要求高效支持超长上下文。DeepSeek-V4 的目标即是打破超长(百万 token)上下文的效率壁垒 ,使长上下文与测试时扩展在工程上可行,以支持 1M 上下文。
test-time scaling(测试时扩展) :传统提升模型靠"训练时扩展"——加大参数、加多数据、加训练算力,收益在训练阶段获得。test-time scaling 则在模型固定 的前提下,于推理时 多花算力来换取更好的答案,即让模型"想得更久":生成更长的思维链、采样多条候选再投票或取最优(majority vote / best-of-N N N )、自我验证并修订、或做搜索。答案质量随推理算力(推理 token 数 / 采样数)大致单调上升,这正是推理模型(如 o1、DeepSeek-R1 一系)的范式。
它与本报告的关系在于:更长的推理意味着更多的生成与被注意 token,即更长的有效上下文 ;而注意力的二次成本使长推理变得昂贵——这正是 V4 用 CSA/HCA 压缩注意力要攻克的瓶颈。换言之,高效的超长上下文是继续做 test-time scaling 的前提 。
两个模型
DeepSeek-V4-Flash
DeepSeek-V4-Pro
总参数 / 激活参数
284B / 13B
1.6T / 49B
层数 / 隐藏维 d d d
43 / 4096
61 / 7168
注意力
前 2 层稠密滑窗,其余 CSA/HCA 交替
前 2 层 HCA,其余 CSA/HCA 交替
CSA:m m m / top-k k k
4 / 512
4 / 1024
HCA:m ′ m' m ′
128
128
注意力头数 n h n_h n h / 头维 c c c
64 / 512
128 / 512
MoE 专家(共享 + 路由,激活)
1 + 256,激活 6
1 + 384,激活 6
MTP 深度 / mHC 扩展 n h c n_{hc} n h c
1 / 4
1 / 4
预训练 token
32T
33T
两者均原生支持 100 万 token 上下文。
效率提升
在 1M 上下文下,V4-Pro 的单 token 推理 FLOPs 约为 V3.2 的 27% 、KV cache 约 10% ;V4-Flash 更进一步,约为 10% 与 7% 。相对常规 BF16 GQA8(num_heads128)基线,1M 场景 KV cache 可压到约 2% 。
核心贡献
保留自 V3 :DeepSeekMoE、MTP。
架构新增 :流形约束超连接 Manifold-Constrained Hyper-Connections(mHC)强化残差、Compressed Sparse Attention (CSA) + Heavily Compressed Attenion (HCA) 混合注意力提升长上下文效率、Muon 作主优化器。
基础设施 :融合的 MoE kernel(通信-计算重叠)、TileLang DSL、batch-invariant 确定性 kernel、面向 Muon/mHC 的训练框架与异构 KV cache 推理框架;后训练引入 FP4 量化感知训练。
训练与后训练 :32T+ token 原生 1M 上下文预训练;后训练以"领域专家 + 在线策略蒸馏(OPD)"取代混合强化学习。
能力定位
知识维度,V4-Pro-Max 在开源模型中达到新 SOTA,但仍落后 Gemini-3.1-Pro;推理维度约落后前沿闭源 3–6 个月;智能体维度与 Kimi-K2.6、GLM-5.1 同档,内部评测超过 Claude Sonnet 4.5、接近 Opus 4.5;长上下文维度,1M 下 MRCR 超过 Gemini-3.1-Pro、落后 Claude Opus 4.6。
本报告结构
章节
对应报告
状态
一、概述
Introduction
✅ 本节
二、架构
§2 Architecture
✅ 详写
三、基础设施
§3 General Infrastructures
🚧 TODO
四、预训练
§4 Pre-Training
🚧 TODO
五、后训练
§5 Post-Training
🚧 TODO
六、结论与局限
§6 Conclusion
🚧 TODO
二、Architecture
V4 保留 V3 的 DeepSeekMoE 与 MTP,新增三项改动:mHC 强化残差连接、CSA/HCA 混合注意力、Muon 优化器。
2.1 MoE 与 MTP
DeepSeekMoE
MoE 的基本思想。 标准 Transformer 每层有一个 FFN,每个 token 都要完整过一遍。MoE 把这个单一 FFN 换成许多并行的小 FFN(称为"专家"),但每个 token 只过其中少数几个 ——由一个轻量的"路由器"挑选。这样总参数可以做得很大(专家很多),而每个 token 的实际计算量只取决于被激活的少数专家。V4-Pro 的"1.6T 总参 / 49B 激活"正是这个意思:384 个路由专家里,每个 token 只激活 6 个。
两类专家。 DeepSeekMoE 把专家分成两类(设共享专家 N s N_s N s 个、路由专家 N r N_r N r 个):
共享专家 (记 F F N 1 ( s ) , … , F F N N s ( s ) \mathrm{FFN}^{(s)}_1,\dots,\mathrm{FFN}^{(s)}_{N_s} F F N 1 ( s ) , … , F F N N s ( s ) ):每个 token 都经过,不参与任何选择;
路由专家 (记 F F N 1 ( r ) , … , F F N N r ( r ) \mathrm{FFN}^{(r)}_1,\dots,\mathrm{FFN}^{(r)}_{N_r} F F N 1 ( r ) , … , F F N N r ( r ) ):每个 token 只经过其中 K r K_r K r 个,由路由器按 token 挑选。"路由(route)“即"把每个 token 分发到选中的那几个专家”——所以只有路由专家 需要打分与选择,共享专家无条件全过。
("细粒度"指把路由专家切得更多、更小,提高"选哪几个组合"的灵活度;共享专家则承载通用计算,让路由专家更易分工。)
路由器(gate)。 路由器是一个独立的小打分网络,为每个路由专家 配一个中心向量 e i ∈ R d \mathbf e_i\in\mathbb{R}^d e i ∈ R d (它们就是 gate 矩阵的各行)。请注意:e i \mathbf e_i e i 是路由器的参数 ,与路由专家 i i i 自己的 FFN 权重 F F N i ( r ) \mathrm{FFN}^{(r)}_i F F N i ( r ) 是两套完全不同的参数 ——e i \mathbf e_i e i 只用于打分/选路,F F N i ( r ) \mathrm{FFN}^{(r)}_i F F N i ( r ) 才做实际计算。
路由计算过程 (对 token t t t ,MoE 层输入为 u t \mathbf u_t u t ):
第一步,打分。 对每个路由专家 i i i ,取 u t \mathbf u_t u t 与其中心向量 e i \mathbf e_i e i 的内积,再过激活函数得亲和度 s i , t s_{i,t} s i , t :
z i , t = u t ⊤ e i z_{i,t}=\mathbf u_t^\top\mathbf e_i
z i , t = u t ⊤ e i
V3 用 Sigmoid:
s i , t = σ ( z i , t ) = 1 1 + e − z i , t ∈ ( 0 , 1 ) s_{i,t}=\sigma(z_{i,t})=\frac{1}{1+e^{-z_{i,t}}}\in(0,1)
s i , t = σ ( z i , t ) = 1 + e − z i , t 1 ∈ ( 0 , 1 )
V4 改用 S o f t p l u s \sqrt{\mathrm{Softplus}} S o f t p l u s (其中 S o f t p l u s ( x ) = ln ( 1 + e x ) \mathrm{Softplus}(x)=\ln(1+e^x) S o f t p l u s ( x ) = ln ( 1 + e x ) ):
s i , t = S o f t p l u s ( z i , t ) = ln ( 1 + e z i , t ) ∈ ( 0 , ∞ ) s_{i,t}=\sqrt{\mathrm{Softplus}(z_{i,t})}=\sqrt{\ln\!\big(1+e^{z_{i,t}}\big)}\in(0,\infty)
s i , t = S o f t p l u s ( z i , t ) = ln ( 1 + e z i , t ) ∈ ( 0 , ∞ )
两者下端都趋于 0,区别在上端:Sigmoid 饱和于 1,S o f t p l u s \sqrt{\mathrm{Softplus}} S o f t p l u s 无上界(随 z z z 增长约为 z \sqrt{z} z )。
第二步,选择 K r K_r K r 个路由专家。 取亲和度最高的 K r K_r K r 个(实际还带一个均衡偏置 b i b_i b i ,其作用与更新见下节):
S t = T o p k ( { s i , t + b i } i = 1 N r , K r ) \mathcal S_t=\mathrm{Topk}\big(\{\,s_{i,t}+b_i\,\}_{i=1}^{N_r},\ K_r\big)
S t = T o p k ( { s i , t + b i } i = 1 N r , K r )
第三步,门控权重。 对被选中的专家 ,把它们的原始 亲和度归一化(不带偏置项b,未选中的为 0):
g i , t = s i , t ∑ j ∈ S t s j , t ( i ∈ S t ) , g i , t = 0 ( i ∉ S t ) g_{i,t}=\frac{s_{i,t}}{\sum_{j\in\mathcal S_t}s_{j,t}}\ \ (i\in\mathcal S_t),\qquad g_{i,t}=0\ \ (i\notin\mathcal S_t)
g i , t = ∑ j ∈ S t s j , t s i , t ( i ∈ S t ) , g i , t = 0 ( i ∈ / S t )
第四步,输出。 token 过全部共享专家,加上选中的路由专家(按 g i , t g_{i,t} g i , t 加权),再加残差:
h t = u t + ∑ i = 1 N s F F N i ( s ) ( u t ) + ∑ i ∈ S t g i , t F F N i ( r ) ( u t ) \mathbf h_t=\mathbf u_t+\sum_{i=1}^{N_s}\mathrm{FFN}^{(s)}_i(\mathbf u_t)+\sum_{i\in\mathcal S_t} g_{i,t}\,\mathrm{FFN}^{(r)}_i(\mathbf u_t)
h t = u t + i = 1 ∑ N s F F N i ( s ) ( u t ) + i ∈ S t ∑ g i , t F F N i ( r ) ( u t )
负载均衡:无辅助损失的偏置法 + 序列级损失
若路由长期集中于少数专家,各专家被使用的频次会差异很大;当专家分布在不同设备上做并行时,这会造成设备间计算量不均,整体被最忙的设备拖慢。DeepSeek(自 V3 起、V4 沿用)的处理分两部分。
(1) 无辅助损失的偏置法。 上面"第二步"里给每个路由专家加的偏置标量 b i b_i b i ,就是均衡的主力。它只参与 top-K r K_r K r 选择,不进入门控权重 (门控权重 g i , t g_{i,t} g i , t 仍由原始亲和度 s i , t s_{i,t} s i , t 归一化)。b i b_i b i 不经反向传播更新,而由一条规则按各专家的实际负载调整(γ \gamma γ 为更新速度,V4 取 0.001 0.001 0 . 0 0 1 ):
b i ← b i − γ ( 专家 i 过载 ) , b i ← b i + γ ( 专家 i 欠载 ) b_i\leftarrow b_i-\gamma\ (\text{专家 }i\ \text{过载}),\qquad b_i\leftarrow b_i+\gamma\ (\text{专家 }i\ \text{欠载})
b i ← b i − γ ( 专家 i 过载 ) , b i ← b i + γ ( 专家 i 欠载 )
由于 b i b_i b i 既不进入门控权重、也不经梯度更新,它改变各专家被选中的频次,但不向模型参数引入额外梯度。这正是"无辅助损失"的确切含义——区别于传统那种会产生梯度、并与语言建模目标竞争的均衡损失。
(2) 序列级均衡损失。 偏置在整个 batch、较长时间尺度上均衡负载,但不保证单条序列内部均衡。为此再加一个权重很小的损失:
L bal = α ∑ i = 1 N r f i P i \mathcal{L}_{\text{bal}}=\alpha\sum_{i=1}^{N_r} f_i\,P_i
L bal = α i = 1 ∑ N r f i P i
其中 f i f_i f i 为该序列中被路由到专家 i i i 的 token 比例,P i P_i P i 为其平均门控概率,α \alpha α 取极小值(V4 为 0.0001 0.0001 0 . 0 0 0 1 )。它仅用于抑制单序列内部的极端不均衡,对主目标的梯度影响可忽略。
V4 相对 V3 的几处 MoE 细节调整
亲和度激活函数 S i g m o i d → S o f t p l u s \mathrm{Sigmoid}\to\sqrt{\mathrm{Softplus}} S i g m o i d → S o f t p l u s (公式见上):去掉了 Sigmoid 在上端的饱和。
取消路由目标节点数的限制,并重设并行 :V3 为省跨节点通信,限制每个 token 最多发往 M M M 个节点;V4 取消这一限制,配合重新设计的并行策略保持效率。
前几层 MoE 改用哈希路由 :V4 把最初几个 block 原本的稠密 FFN 也换成 MoE(于是所有 block 都是 MoE ),但最前面 3 个 MoE 层用固定的哈希路由 ——按 token ID 经预定义哈希函数直接指定专家,而非用学习出来的路由器。早期层的学习路由往往不稳、意义弱,用确定性哈希可廉价地得到稳定且均衡的分配。
MTP:与 V3 完全一致
常规语言模型训练里,每个位置只学"预测下一个 token",一个位置只有一个监督信号。MTP 让每个位置额外再预测更靠后的 token ,于是一个位置有多个监督信号。两个好处:(1) 监督更密,数据效率更高;(2) 促使隐表示提前编码未来 token 的信息(因为要预测更远的 token)。它是训练期的辅助目标 ,推理时可丢弃。
先约定几个部件 (除特别说明外都是 d d d 维,d d d 为隐藏维):
h i h_i h i :主模型在位置 i i i 算出的隐向量。常规训练就是用它预测下一个 token t i + 1 t_{i+1} t i + 1 ;
E m b ( t ) \mathrm{Emb}(t) E m b ( t ) :token t t t 的 embedding(查表得到);
O u t H e a d ( ⋅ ) \mathrm{OutHead}(\cdot) O u t H e a d ( ⋅ ) :输出头,把一个隐向量映成词表上的概率分布(线性层 + softmax);
T R M k \mathrm{TRM}_k T R M k :第 k k k 个 MTP 模块专属的标准 Transformer block(自注意力 + 前馈),结构与主干层相同、单独训练;
M k M_k M k :第 k k k 个 MTP 模块的投影矩阵(d × 2 d d\times 2d d × 2 d ),把拼接出的 2 d 2d 2 d 维向量压回 d d d 维。
其中 embedding 表与输出头都与主模型共享 ,所以一个 MTP 模块只新增"1 个 T R M k \mathrm{TRM}_k T R M k + 1 个 M k M_k M k "。
用 D D D 记 MTP 的深度,即每个位置额外预测的 token 数,也等于串联的 MTP 模块数;T T T 记序列长度。V4 取 D = 1 D=1 D = 1 ,即只加一个 MTP 模块。
下面给出第 k k k 个模块(k = 1 , … , D k=1,\dots,D k = 1 , … , D )的通式。该模块在位置 i i i 处接收前一深度的隐向量 h i k − 1 h_i^{k-1} h i k − 1 (约定 h i 0 = h i h_i^{0}=h_i h i 0 = h i 为主模型隐向量)与第 i + k i+k i + k 个 token 的 embedding,预测 token t i + k + 1 t_{i+k+1} t i + k + 1 ,分三步。
第一步,把 h i k − 1 h_i^{k-1} h i k − 1 与真实 token t i + k t_{i+k} t i + k 的 embedding 各自 RMSNorm 后拼接,再用 M k M_k M k 压回 d d d 维:
h i ′ k = M k [ R M S N o r m ( h i k − 1 ) ; R M S N o r m ( E m b ( t i + k ) ) ] h_i^{\prime k}=M_k\big[\,\mathrm{RMSNorm}(h_i^{k-1})\;;\;\mathrm{RMSNorm}(\mathrm{Emb}(t_{i+k}))\,\big]
h i ′ k = M k [ R M S N o r m ( h i k − 1 ) ; R M S N o r m ( E m b ( t i + k ) ) ]
第二步,整段序列送入该模块专属的 Transformer block(T R M k \mathrm{TRM}_k T R M k 作用于切片 h 1 : T − k ′ k h_{1:T-k}^{\prime k} h 1 : T − k ′ k ,下标 1 : T − k 1{:}T{-}k 1 : T − k 即位置 1 1 1 到 T − k T-k T − k ;因果自注意力使位置 i i i 的输出依赖 h ≤ i ′ k h_{\le i}^{\prime k} h ≤ i ′ k ):
h 1 : T − k k = T R M k ( h 1 : T − k ′ k ) h_{1:T-k}^{k}=\mathrm{TRM}_k(h_{1:T-k}^{\prime k})
h 1 : T − k k = T R M k ( h 1 : T − k ′ k )
第三步,用共享输出头给出对 t i + k + 1 t_{i+k+1} t i + k + 1 的预测分布:
P i + k + 1 k = O u t H e a d ( h i k ) P_{i+k+1}^{k}=\mathrm{OutHead}(h_i^{k})
P i + k + 1 k = O u t H e a d ( h i k )
V4 的情形(D = 1 D=1 D = 1 ,即只取 k = 1 k=1 k = 1 、h i 0 = h i h_i^{0}=h_i h i 0 = h i )。 主模型用 h i h_i h i 预测下一个 token t i + 1 t_{i+1} t i + 1 (常规语言模型损失);MTP 模块额外预测下下个 token t i + 2 t_{i+2} t i + 2 。例如处理到位置 5:主模型预测第 6 个 token;MTP 模块把 h 5 h_5 h 5 与第 6 个真实 token 的 embedding 拼起来,预测第 7 个 token。
第一步喂进去的是真实 的 t i + 1 t_{i+1} t i + 1 (取自训练数据,不是模型的猜测),所以预测 t i + 2 t_{i+2} t i + 2 是在"已知真实 t i + 1 t_{i+1} t i + 1 “的条件下进行的。这就是"保留因果链”,区别于"用多个独立头一次性并行猜多个 token"的做法。
关于 D > 1 D>1 D > 1 (V4 未采用,列出以便对照)。 上面三式对任意 k k k 通用:k k k 是第几个额外 token 的下标、i i i 是 token 在序列中的位置。串联 D D D 个模块时,第 k k k 个模块预测"再往后第 k k k 个"额外 token t i + k + 1 t_{i+k+1} t i + k + 1 ,其输入隐向量 h i k − 1 h_i^{k-1} h i k − 1 取自前一个模块的输出(k = 1 k=1 k = 1 时即主模型 h i 0 = h i h_i^{0}=h_i h i 0 = h i );各模块有独立的 T R M k \mathrm{TRM}_k T R M k 、M k M_k M k ,而 embedding 表与输出头始终共享。
损失。 每个深度 k k k 先各算一个交叉熵(对深度 k k k 所有有监督的目标位置求和,P i k [ t i ] P_i^{k}[t_i] P i k [ t i ] 为第 k k k 个模块赋予位置 i i i 真实 token t i t_i t i 的概率):
L MTP k = C r o s s E n t r o p y ( P 2 + k : T + 1 k , t 2 + k : T + 1 ) = − 1 T ∑ i = 2 + k T + 1 log P i k [ t i ] \mathcal L_{\text{MTP}}^{k}=\mathrm{CrossEntropy}(P_{2+k:T+1}^{k},\,t_{2+k:T+1})=-\frac{1}{T}\sum_{i=2+k}^{T+1}\log P_i^{k}[t_i]
L MTP k = C r o s s E n t r o p y ( P 2 + k : T + 1 k , t 2 + k : T + 1 ) = − T 1 i = 2 + k ∑ T + 1 log P i k [ t i ]
再按深度数 D D D 取平均 、乘权重 λ \lambda λ ,作为附加项加到主损失上:
L MTP = λ D ∑ k = 1 D L MTP k \mathcal L_{\text{MTP}}=\frac{\lambda}{D}\sum_{k=1}^{D}\mathcal L_{\text{MTP}}^{k}
L MTP = D λ k = 1 ∑ D L MTP k
(D = 1 D=1 D = 1 时即单个交叉熵乘 λ \lambda λ ;V4 取 λ = 0.3 \lambda=0.3 λ = 0 . 3 ,学习率衰减阶段降到 0.1 0.1 0 . 1 。)
推理时。 MTP 只服务于训练。推理可直接丢弃这些模块,主模型照常独立运行,零额外开销;也可把它当作投机解码的草稿器以降低延迟(接受/拒绝采样原理见本博客 投机采样 / 投机解码 )。
2.2 Manifold-Constrained Hyper-Connections
标准 Hyper-Connections:把残差流展宽为 n h c n_{hc} n h c 条
标准残差 x l + 1 = x l + F l ( x l ) x_{l+1}=x_l+\mathcal{F}_l(x_l) x l + 1 = x l + F l ( x l ) 只有一条流、固定权重 + 1 +1 + 1 ,在梯度通畅与表示坍缩之间存在张力。超连接(HC)把残差流由 R d \mathbb{R}^{d} R d 展宽到 R n h c × d \mathbb{R}^{n_{hc}\times d} R n h c × d ——即 n h c n_{hc} n h c 条并行 的 d d d 维残差流(V4 取 n h c = 4 n_{hc}=4 n h c = 4 ):
X l = [ x l , 1 ; … ; x l , n h c ] ⊤ ∈ R n h c × d X_l=[\mathbf{x}_{l,1};\dots;\mathbf{x}_{l,n_{hc}}]^\top\in\mathbb{R}^{n_{hc}\times d}
X l = [ x l , 1 ; … ; x l , n h c ] ⊤ ∈ R n h c × d
引入三个映射 A l ∈ R 1 × n h c A_l\in\mathbb{R}^{1\times n_{hc}} A l ∈ R 1 × n h c 、B l ∈ R n h c × n h c B_l\in\mathbb{R}^{n_{hc}\times n_{hc}} B l ∈ R n h c × n h c 、C l ∈ R n h c × 1 C_l\in\mathbb{R}^{n_{hc}\times 1} C l ∈ R n h c × 1 ,更新为:
X l + 1 = B l X l + C l F l ( A l X l ) (1) X_{l+1}=B_l X_l+C_l\,\mathcal{F}_l(A_l X_l)\tag{1}
X l + 1 = B l X l + C l F l ( A l X l ) ( 1 )
n h c n_{hc} n h c 是残差流的并行条数,不改变层内计算维度。逐项核对式 (1) 的形状即可看清分工:
运算
形状
作用
A l X l A_l X_l A l X l
( 1 × n h c ) ( n h c × d ) = 1 × d (1{\times}n_{hc})(n_{hc}{\times}d)=1{\times}d ( 1 × n h c ) ( n h c × d ) = 1 × d
把 n h c n_{hc} n h c 条流加权合成 1 个 d d d 维向量,喂给层
F l ( ⋅ ) \mathcal{F}_l(\cdot) F l ( ⋅ )
d → d d\to d d → d
注意力 / MoE 正常在 d d d 维计算
C l F l ( ⋅ ) C_l\,\mathcal{F}_l(\cdot) C l F l ( ⋅ )
( n h c × 1 ) ( 1 × d ) = n h c × d (n_{hc}{\times}1)(1{\times}d)=n_{hc}{\times}d ( n h c × 1 ) ( 1 × d ) = n h c × d
把层输出广播回 n h c n_{hc} n h c 条流
B l X l B_l X_l B l X l
( n h c × n h c ) ( n h c × d ) = n h c × d (n_{hc}{\times}n_{hc})(n_{hc}{\times}d)=n_{hc}{\times}d ( n h c × n h c ) ( n h c × d ) = n h c × d
n h c n_{hc} n h c 条流之间相互混合
层进去前由 A l A_l A l 降回 d d d 、出来后由 C l C_l C l 摊回 n h c n_{hc} n h c 条,注意力与 MoE 始终在 d d d 维上运算。n h c n_{hc} n h c 只是给跳连接结构额外增加的一根轴,且 n h c ≪ d n_{hc}\ll d n h c ≪ d ,开销很小。
Pre-Block Mixing = A l =A_l = A l (输入合成)、Post-Block Mixing = C l =C_l = C l (输出广播)、Residual Mixing = B l =B_l = B l (残差流混合,替代标准残差的恒等跳连)。每个 Transformer 层含两套这样的结构,分别包住注意力子层与 MoE 子层。
流形约束:把 B l B_l B l 限制为双随机矩阵
裸 HC 堆叠多层时数值不稳。mHC 的核心是把残差混合矩阵 B l B_l B l 约束到双随机矩阵 集合(Birkhoff 多面体):
M = { M ∈ R n × n ∣ M 1 n = 1 n , 1 n ⊤ M = 1 n ⊤ , M ≥ 0 } \mathcal{M}=\{M\in\mathbb{R}^{n\times n}\mid M\mathbf{1}_n=\mathbf{1}_n,\ \mathbf{1}_n^\top M=\mathbf{1}_n^\top,\ M\ge 0\}
M = { M ∈ R n × n ∣ M 1 n = 1 n , 1 n ⊤ M = 1 n ⊤ , M ≥ 0 }
其中 1 n \mathbf 1_n 1 n 是全 1 向量、M ≥ 0 M\ge 0 M ≥ 0 指逐元素非负;该集合即"每行和、每列和均为 1 且元素非负"的矩阵。约束通过一个可微投影 (Sinkhorn–Knopp 迭代)施加,具体过程见下一小节;A l , C l A_l,C_l A l , C l 则用 Sigmoid 约束为非负且有界(A l = σ ( A ~ l ) A_l=\sigma(\tilde A_l) A l = σ ( A ~ l ) ,C l = 2 σ ( C ~ l ) C_l=2\sigma(\tilde C_l) C l = 2 σ ( C ~ l ) )。三个映射的参数都由“输入相关的动态分量 + 静态偏置”动态生成。
类比:softmax 把任意向量投影到概率单纯形(非负、和为 1);Sinkhorn-Knopp 是它在矩阵上的二维推广(行、列同时归一化)。
双随机投影的迭代过程(Sinkhorn–Knopp)
约束前的原始矩阵 B ~ l \tilde B_l B ~ l 由动态参数化生成——残差状态拉平并归一化:
X ^ l = R M S N o r m ( v e c ( X l ) ) ∈ R 1 × n h c d \hat X_l=\mathrm{RMSNorm}(\mathrm{vec}(X_l))\in\mathbb{R}^{1\times n_{hc}d}
X ^ l = R M S N o r m ( v e c ( X l ) ) ∈ R 1 × n h c d
再线性投影、reshape 成方阵,加门控与静态偏置:
B ~ l = α l r e s ⋅ M a t ( X ^ l W l r e s ) + S l r e s ∈ R n h c × n h c \tilde B_l=\alpha_l^{\mathrm{res}}\cdot\mathrm{Mat}\big(\hat X_l\,W_l^{\mathrm{res}}\big)+S_l^{\mathrm{res}}\in\mathbb{R}^{n_{hc}\times n_{hc}}
B ~ l = α l r e s ⋅ M a t ( X ^ l W l r e s ) + S l r e s ∈ R n h c × n h c
其中 W l r e s ∈ R n h c d × n h c 2 W_l^{\mathrm{res}}\in\mathbb{R}^{n_{hc}d\times n_{hc}^2} W l r e s ∈ R n h c d × n h c 2 ,M a t ( ⋅ ) \mathrm{Mat}(\cdot) M a t ( ⋅ ) 把 1 × n h c 2 1\times n_{hc}^2 1 × n h c 2 向量重排成 n h c × n h c n_{hc}\times n_{hc} n h c × n h c 矩阵,α l r e s \alpha_l^{\mathrm{res}} α l r e s 是初始化为小值的门控标量,S l r e s S_l^{\mathrm{res}} S l r e s 是静态偏置。
下面记 n = n h c n=n_{hc} n = n h c 。先逐元素取指数以保正:
M ( 0 ) = exp ( B ~ l ) M^{(0)}=\exp(\tilde B_l)
M ( 0 ) = exp ( B ~ l )
定义行归一化算子(每行除以行和):
[ T r ( M ) ] i j = M i j ∑ j ′ = 1 n M i j ′ [\mathcal{T}_r(M)]_{ij}=\frac{M_{ij}}{\sum_{j'=1}^{n}M_{ij'}}
[ T r ( M ) ] i j = ∑ j ′ = 1 n M i j ′ M i j
定义列归一化算子(每列除以列和):
[ T c ( M ) ] i j = M i j ∑ i ′ = 1 n M i ′ j [\mathcal{T}_c(M)]_{ij}=\frac{M_{ij}}{\sum_{i'=1}^{n}M_{i'j}}
[ T c ( M ) ] i j = ∑ i ′ = 1 n M i ′ j M i j
交替迭代(先列后行):
M ( t ) = T r ( T c ( M ( t − 1 ) ) ) , t = 1 , … , t max M^{(t)}=\mathcal{T}_r\big(\mathcal{T}_c(M^{(t-1)})\big),\quad t=1,\dots,t_{\max}
M ( t ) = T r ( T c ( M ( t − 1 ) ) ) , t = 1 , … , t m a x
取末次迭代为结果(t max = 20 t_{\max}=20 t m a x = 2 0 ):
B l = M ( t max ) B_l=M^{(t_{\max})}
B l = M ( t m a x )
每轮最后一步是行归一化 T r \mathcal{T}_r T r ,故终止时行和精确为 1、列和近似为 1 ;迭代足够多时两者都趋于 1。
等价的对角缩放形式(Sinkhorn–Knopp 定理)。 对正矩阵 M ( 0 ) M^{(0)} M ( 0 ) 存在唯一(至多差一标量)的正对角 D 1 = d i a g ( r ) D_1=\mathrm{diag}(r) D 1 = d i a g ( r ) 、D 2 = d i a g ( c ) D_2=\mathrm{diag}(c) D 2 = d i a g ( c ) ,使缩放后双随机:
B l = D 1 M ( 0 ) D 2 B_l=D_1\,M^{(0)}\,D_2
B l = D 1 M ( 0 ) D 2
交替归一化即隐式求解这两组因子,更新式为:
r i ← 1 ∑ j M i j ( 0 ) c j r_i\leftarrow\frac{1}{\sum_{j}M^{(0)}_{ij}\,c_j}
r i ← ∑ j M i j ( 0 ) c j 1
c j ← 1 ∑ i M i j ( 0 ) r i c_j\leftarrow\frac{1}{\sum_{i}M^{(0)}_{ij}\,r_i}
c j ← ∑ i M i j ( 0 ) r i 1
输入 / 输出映射 A l , C l A_l, C_l A l , C l 的生成、取值与系数 2(推断)
A l , C l A_l,C_l A l , C l 与 B l B_l B l 一样动态生成 :先取归一化的残差状态 X ^ l = R M S N o r m ( v e c ( X l ) ) ∈ R 1 × n h c d \hat X_l=\mathrm{RMSNorm}(\mathrm{vec}(X_l))\in\mathbb{R}^{1\times n_{hc}d} X ^ l = R M S N o r m ( v e c ( X l ) ) ∈ R 1 × n h c d (见上一小节),再各自投影出动态分量、加静态偏置,得到约束前的原始量:
A ~ l = α l p r e ( X ^ l W l p r e ) + S l p r e \tilde A_l=\alpha_l^{\mathrm{pre}}\,(\hat X_l\,W_l^{\mathrm{pre}})+S_l^{\mathrm{pre}}
A ~ l = α l p r e ( X ^ l W l p r e ) + S l p r e
C ~ l = α l p o s t ( X ^ l W l p o s t ) ⊤ + S l p o s t \tilde C_l=\alpha_l^{\mathrm{post}}\,(\hat X_l\,W_l^{\mathrm{post}})^\top+S_l^{\mathrm{post}}
C ~ l = α l p o s t ( X ^ l W l p o s t ) ⊤ + S l p o s t
再过 Sigmoid 得最终映射(保证非负、有界):
A l = σ ( A ~ l ) ∈ ( 0 , 1 ) 1 × n h c A_l=\sigma(\tilde A_l)\in(0,1)^{1\times n_{hc}}
A l = σ ( A ~ l ) ∈ ( 0 , 1 ) 1 × n h c
C l = 2 σ ( C ~ l ) ∈ ( 0 , 2 ) n h c × 1 C_l=2\sigma(\tilde C_l)\in(0,2)^{n_{hc}\times 1}
C l = 2 σ ( C ~ l ) ∈ ( 0 , 2 ) n h c × 1
参数含义:
X ^ l \hat X_l X ^ l :残差状态拉平并归一化后的向量,使映射随输入变化 (即动态分量的来源);
W l p r e , W l p o s t ∈ R n h c d × n h c W_l^{\mathrm{pre}},W_l^{\mathrm{post}}\in\mathbb{R}^{n_{hc}d\times n_{hc}} W l p r e , W l p o s t ∈ R n h c d × n h c :可学习投影矩阵,产生 A l / C l A_l/C_l A l / C l 的输入相关动态分量 ;
S l p r e ∈ R 1 × n h c S_l^{\mathrm{pre}}\in\mathbb{R}^{1\times n_{hc}} S l p r e ∈ R 1 × n h c 、S l p o s t ∈ R n h c × 1 S_l^{\mathrm{post}}\in\mathbb{R}^{n_{hc}\times 1} S l p o s t ∈ R n h c × 1 :静态偏置 (与输入无关),给出映射的基准值;
α l p r e , α l p o s t ∈ R \alpha_l^{\mathrm{pre}},\alpha_l^{\mathrm{post}}\in\mathbb{R} α l p r e , α l p o s t ∈ R :可学习的标量门控 ,初始化为小值 。
初始化时的行为: α \alpha α 初始为小值 ⇒ \Rightarrow ⇒ 训练初期动态分量 ≈ 0 \approx 0 ≈ 0 ,于是 A ~ l ≈ S l p r e \tilde A_l\approx S_l^{\mathrm{pre}} A ~ l ≈ S l p r e 、C ~ l ≈ S l p o s t \tilde C_l\approx S_l^{\mathrm{post}} C ~ l ≈ S l p o s t ,映射主要由静态偏置决定;当静态偏置也接近 0 时,A l ≈ σ ( 0 ) = 0.5 A_l\approx\sigma(0)=0.5 A l ≈ σ ( 0 ) = 0 . 5 、C l ≈ 2 σ ( 0 ) = 1 C_l\approx 2\sigma(0)=1 C l ≈ 2 σ ( 0 ) = 1 ——这正是下表"中性点"一栏,也对应下面"系数 2"讨论的初始情形。
下表汇总两个映射的取值范围与"中性点"(原始参数 ≈ 0 \approx 0 ≈ 0 时):
映射
公式
范围
中性点
输入 A l A_l A l
σ ( A ~ l ) \sigma(\tilde A_l) σ ( A ~ l )
( 0 , 1 ) (0,1) ( 0 , 1 )
σ ( 0 ) = 0.5 \sigma(0)=0.5 σ ( 0 ) = 0 . 5
输出 C l C_l C l
2 σ ( C ~ l ) 2\sigma(\tilde C_l) 2 σ ( C ~ l )
( 0 , 2 ) (0,2) ( 0 , 2 )
2 σ ( 0 ) = 1 2\sigma(0)=1 2 σ ( 0 ) = 1
论文只说明 Sigmoid 用于保证非负与有界(避免信号抵消),并未解释 C l C_l C l 为何额外乘 2。以下为结合超连接初始化惯例的推断 :
C l C_l C l 决定层输出写回各残差流的权重,而标准残差中这个权重恰为 1。系数 2 使 C l C_l C l 的中性点正好落在 1(2 σ ( 0 ) = 1 2\sigma(0)=1 2 σ ( 0 ) = 1 ),于是初始化时每层即以单位权重贡献输出,并留出到 2 的放大余量。若改用不带系数的 σ \sigma σ ,输出权重初始仅 0.5、且永远到不了 1。输入映射 A l A_l A l 是对 n h c n_{hc} n h c 条流的加权读取,其结果下游还会经 RMSNorm 归一化、绝对尺度不关键,故落在 ( 0 , 1 ) (0,1) ( 0 , 1 ) 、中性点 0.5 即可,无需该系数。
双随机 ⇒ \Rightarrow ⇒ 谱范数 = 1 =1 = 1 ,但并非所有奇异值都为 1
双随机矩阵的谱范数(最大奇异值)恒为 1。
记号:下面 ∥ B x ∥ 2 , ∥ x ∥ 2 \lVert Bx\rVert_2,\lVert x\rVert_2 ∥ B x ∥ 2 , ∥ x ∥ 2 是向量 的欧氏范数 ∑ k ( ⋅ ) k 2 \sqrt{\sum_k(\cdot)_k^2} ∑ k ( ⋅ ) k 2 ;∥ B ∥ 2 \lVert B\rVert_2 ∥ B ∥ 2 是矩阵 的谱范数(最大奇异值)。二者由定义 ∥ B ∥ 2 = max x ≠ 0 ∥ B x ∥ 2 / ∥ x ∥ 2 \lVert B\rVert_2=\max_{x\ne0}\lVert Bx\rVert_2/\lVert x\rVert_2 ∥ B ∥ 2 = max x = 0 ∥ B x ∥ 2 / ∥ x ∥ 2 相连——同一个下标 2,作用在向量上是长度、在矩阵上是最大奇异值。
先看逐行 的关键一步:每一行 i i i 非负且和为 1,是一个概率分布,于是 ( B x ) i = ∑ j B i j x j (Bx)_i=\sum_j B_{ij}x_j ( B x ) i = ∑ j B i j x j 是 { x j } \{x_j\} { x j } 在该分布下的均值 。由"均值的平方 ≤ \le ≤ 平方的均值 "(设随机变量 X X X 以概率 B i j B_{ij} B i j 取值 x j x_j x j ,则这等价于方差非负 E [ X 2 ] − ( E [ X ] ) 2 ≥ 0 \mathbb E[X^2]-(\mathbb E[X])^2\ge 0 E [ X 2 ] − ( E [ X ] ) 2 ≥ 0 ,也即 t 2 t^2 t 2 的 Jensen 不等式):
( ∑ j B i j x j ) 2 ≤ ∑ j B i j x j 2 \Big(\sum_j B_{ij}x_j\Big)^2\le\sum_j B_{ij}x_j^2
( j ∑ B i j x j ) 2 ≤ j ∑ B i j x j 2
对所有行 i i i 相加(左边即 ∥ B x ∥ 2 2 \lVert Bx\rVert_2^2 ∥ B x ∥ 2 2 ),再用列和为 1 把右边塌缩:
∥ B x ∥ 2 2 = ∑ i ( ∑ j B i j x j ) 2 ≤ ∑ i ∑ j B i j x j 2 = ∑ j x j 2 ∑ i B i j ⏟ = 1 = ∥ x ∥ 2 2 \lVert Bx\rVert_2^2=\sum_i\Big(\sum_j B_{ij}x_j\Big)^2\le\sum_i\sum_j B_{ij}x_j^2=\sum_j x_j^2\underbrace{\sum_i B_{ij}}_{=1}=\lVert x\rVert_2^2
∥ B x ∥ 2 2 = i ∑ ( j ∑ B i j x j ) 2 ≤ i ∑ j ∑ B i j x j 2 = j ∑ x j 2 = 1 i ∑ B i j = ∥ x ∥ 2 2
上式对任意 向量 x x x 成立,即 ∥ B x ∥ 2 ≤ ∥ x ∥ 2 \lVert Bx\rVert_2\le\lVert x\rVert_2 ∥ B x ∥ 2 ≤ ∥ x ∥ 2 。下面用夹逼得到谱范数恰为 1。
上界。 由谱范数定义:
∥ B ∥ 2 = max x ≠ 0 ∥ B x ∥ 2 ∥ x ∥ 2 ≤ 1 \lVert B\rVert_2=\max_{x\ne0}\frac{\lVert Bx\rVert_2}{\lVert x\rVert_2}\le 1
∥ B ∥ 2 = x = 0 max ∥ x ∥ 2 ∥ B x ∥ 2 ≤ 1
下界。 谱范数是对所有方向取 max,故代入任一具体向量都给出下界 。取 x = 1 x=\mathbf 1 x = 1 (全 1 向量);因每行和为 1,( B 1 ) i = ∑ j B i j = 1 (B\mathbf 1)_i=\sum_j B_{ij}=1 ( B 1 ) i = ∑ j B i j = 1 ,即 B 1 = 1 B\mathbf 1=\mathbf 1 B 1 = 1 ,于是
∥ B ∥ 2 ≥ ∥ B 1 ∥ 2 ∥ 1 ∥ 2 = ∥ 1 ∥ 2 ∥ 1 ∥ 2 = 1 \lVert B\rVert_2\ge\frac{\lVert B\mathbf 1\rVert_2}{\lVert\mathbf 1\rVert_2}=\frac{\lVert\mathbf 1\rVert_2}{\lVert\mathbf 1\rVert_2}=1
∥ B ∥ 2 ≥ ∥ 1 ∥ 2 ∥ B 1 ∥ 2 = ∥ 1 ∥ 2 ∥ 1 ∥ 2 = 1
1 ≤ ∥ B ∥ 2 ≤ 1 1\le\lVert B\rVert_2\le 1 1 ≤ ∥ B ∥ 2 ≤ 1 ,故 ∥ B ∥ 2 = 1 \lVert B\rVert_2=1 ∥ B ∥ 2 = 1 。
(上界的不等号用"行和 = 1 ⇒ 每行是分布、方差非负"、塌缩用"列和 = 1";下界只用到 B 1 = 1 B\mathbf 1=\mathbf 1 B 1 = 1 ,即行和 = 1。)
但其余奇异值一般 < 1 <1 < 1 。例如 ( 0.7 0.3 0.3 0.7 ) \left(\begin{smallmatrix}0.7&0.3\\0.3&0.7\end{smallmatrix}\right) ( 0 . 7 0 . 3 0 . 3 0 . 7 ) 的奇异值为 { 1 , 0.4 } \{1,0.4\} { 1 , 0 . 4 } 。因此 B l B_l B l 是非扩张但通常不正交 的:沿 1 \mathbf{1} 1 方向增益恰为 1(保长),其余方向被收缩。只有置换矩阵(含单位阵,即标准残差)这类 Birkhoff 顶点才全为 1。这一点与 Muon 把全部奇异值拉成 1 的"完全正交"形成对照(见 2.4 与 Muon 笔记 )。
多层传递为何既不爆炸也不消失
把式 (1) 的 block 贡献记作 D l = C l F l ( A l X l ) D_l=C_l\mathcal{F}_l(A_lX_l) D l = C l F l ( A l X l ) ,展开 L L L 层:
X L = ( ∏ l B l ) X 0 + ∑ l ( ∏ j > l B j ) D l X_L=\Big(\textstyle\prod_{l} B_l\Big)X_0+\sum_{l}\Big(\textstyle\prod_{j>l}B_j\Big)D_l
X L = ( ∏ l B l ) X 0 + ∑ l ( ∏ j > l B j ) D l
爆炸风险来自 B B B 的连乘。若 ∥ B l ∥ 2 > 1 \lVert B_l\rVert_2>1 ∥ B l ∥ 2 > 1 (哪怕 1.1),连乘约 ρ L \rho^L ρ L ,指数爆炸;而双随机给出 ∥ B l ∥ 2 = 1 \lVert B_l\rVert_2=1 ∥ B l ∥ 2 = 1 、连乘仍 ≤ 1 \le 1 ≤ 1 ,于是 ∥ X L ∥ ≤ ∥ X 0 ∥ + ∑ l ∥ D l ∥ \lVert X_L\rVert\le\lVert X_0\rVert+\sum_l\lVert D_l\rVert ∥ X L ∥ ≤ ∥ X 0 ∥ + ∑ l ∥ D l ∥ ,至多线性增长 。
消失风险则被列和守恒直接排除。对式 (1) 左乘 1 ⊤ \mathbf{1}^\top 1 ⊤ (即对 n h c n_{hc} n h c 条流求和),因 1 ⊤ B l = 1 ⊤ \mathbf{1}^\top B_l=\mathbf{1}^\top 1 ⊤ B l = 1 ⊤ :
s l + 1 : = 1 ⊤ X l + 1 = 1 ⊤ X l + 1 ⊤ D l = s l + 1 ⊤ D l s_{l+1}:=\mathbf{1}^\top X_{l+1}=\mathbf{1}^\top X_l+\mathbf{1}^\top D_l=s_l+\mathbf{1}^\top D_l
s l + 1 : = 1 ⊤ X l + 1 = 1 ⊤ X l + 1 ⊤ D l = s l + 1 ⊤ D l
跨流之和是一个纯加法累加器 ,B l B_l B l 的衰减在求和中精确抵消——既不指数放大,也不衰减归零。而末端正是把 n h c n_{hc} n h c 条流聚合后读出,读的就是这个被守恒的量。被收缩的只是与 1 \mathbf{1} 1 正交的"流间差异",且每层有新内容续上。反向传播乘 B l ⊤ B_l^\top B l ⊤ (仍双随机) ,梯度同理不爆不消。
2.3 混合注意力:CSA 与 HCA
百万级上下文下注意力是主要算力 / 显存瓶颈。V4 用两种压缩注意力逐层交替 :
CSA (Compressed Sparse Attention):中等压缩(每 m m m 个 token 合 1 个条目)+ 稀疏(每个 query 只看 top-k k k 个压缩块);
HCA (Heavily Compressed Attention):重度压缩(每 m ′ ≫ m m'\gg m m ′ ≫ m 个 token 合 1 个)+ 稠密(看全部压缩块)。
两者核心注意力都用共享 KV 的 MQA ,都配滑动窗口分支、部分 RoPE 与 RMSNorm。符号与超参(取值列 V4-Pro / V4-Flash):
符号
含义
取值
n , d n,\ d n , d
序列长度、隐藏维
— / 7168、4096
c c c
KV / 注意力头维
512
m , m ′ m,\ m' m , m ′
CSA / HCA 压缩率
4 / 128
n h n_h n h
注意力 query 头数
128 / 64
d c d_c d c
query 低秩(潜)维
1536 / 1024
c I , n h I c^I,\ n_h^I c I , n h I
索引器头维、头数
128、64
k k k
top-k k k 选块数(仅 CSA)
1024 / 512
g , d g g,\ d_g g , d g
输出分组数、组内中间维
16 / 8,1024
n w i n n_{win} n w i n
滑动窗口长度
128
r r r
施加 RoPE 的维数(末段)
64
记本注意力层输入 H = [ h 1 ; … ; h n ] ∈ R n × d H=[h_1;\dots;h_n]\in\mathbb{R}^{n\times d} H = [ h 1 ; … ; h n ] ∈ R n × d ,h t ∈ R d h_t\in\mathbb{R}^d h t ∈ R d 为第 t t t 个 token 的隐状态。以下以 CSA 为主线,HCA 差异见后。
CSA·第一步:KV 压缩
两组线性投影(W a K V , W b K V , W a Z , W b Z ∈ R d × c W^{aKV},W^{bKV},W^{aZ},W^{bZ}\in\mathbb{R}^{d\times c} W a K V , W b K V , W a Z , W b Z ∈ R d × c )得到两路 KV 条目与各自的压缩打分:
C a = H W a K V , C b = H W b K V ∈ R n × c C^a=HW^{aKV},\qquad C^b=HW^{bKV}\ \in\mathbb{R}^{n\times c}
C a = H W a K V , C b = H W b K V ∈ R n × c
Z a = H W a Z , Z b = H W b Z ∈ R n × c Z^a=HW^{aZ},\qquad Z^b=HW^{bZ}\ \in\mathbb{R}^{n\times c}
Z a = H W a Z , Z b = H W b Z ∈ R n × c
每 m m m 个相邻 token 合成 1 个条目。第 i i i 个条目取材自"本块 C a C^a C a 的 m m m 行 + 前一块 C b C^b C b 的 m m m 行"共 2 m 2m 2 m 行;对这 2 m 2m 2 m 个打分(加可学习位置偏置 B a , B b ∈ R m × c B^a,B^b\in\mathbb{R}^{m\times c} B a , B b ∈ R m × c )做逐通道 softmax 得权重:
[ S m i : m ( i + 1 ) − 1 a ; S m ( i − 1 ) : m i − 1 b ] = S o f t m a x row ( [ Z m i : m ( i + 1 ) − 1 a + B a ; Z m ( i − 1 ) : m i − 1 b + B b ] ) \big[S^a_{mi:m(i+1)-1};\ S^b_{m(i-1):mi-1}\big]=\mathrm{Softmax}_{\text{row}}\!\big(\big[Z^a_{mi:m(i+1)-1}+B^a;\ Z^b_{m(i-1):mi-1}+B^b\big]\big)
[ S m i : m ( i + 1 ) − 1 a ; S m ( i − 1 ) : m i − 1 b ] = S o f t m a x row ( [ Z m i : m ( i + 1 ) − 1 a + B a ; Z m ( i − 1 ) : m i − 1 b + B b ] )
再加权求和(⊙ \odot ⊙ 逐元素乘):
C i Comp = ∑ j = m i m ( i + 1 ) − 1 S j a ⊙ C j a + ∑ j = m ( i − 1 ) m i − 1 S j b ⊙ C j b ∈ R c C_i^{\text{Comp}}=\sum_{j=mi}^{m(i+1)-1}S_j^a\odot C_j^a+\sum_{j=m(i-1)}^{mi-1}S_j^b\odot C_j^b\ \in\mathbb{R}^c
C i Comp = j = m i ∑ m ( i + 1 ) − 1 S j a ⊙ C j a + j = m ( i − 1 ) ∑ m i − 1 S j b ⊙ C j b ∈ R c
(i = 0 i=0 i = 0 时前块不存在:Z b Z^b Z b 块补 − ∞ -\infty − ∞ 、C b C^b C b 块补 0 0 0 。)相邻条目取材在块边界重叠,故序列压到 1 / m 1/m 1 / m :C Comp ∈ R n m × c C^{\text{Comp}}\in\mathbb{R}^{\frac{n}{m}\times c} C Comp ∈ R m n × c 。
CSA·第二步:Lightning Indexer 选 top-k k k
索引器用低维 c I c^I c I 给"query–压缩块"打分,只负责选块、不参与核心注意力(可 FP4 计算)。
索引器键 :用与上面相同的压缩流程、但把 KV 投到 c I c^I c I 维,得 K IComp ∈ R n m × c I K^{\text{IComp}}\in\mathbb{R}^{\frac{n}{m}\times c^I} K IComp ∈ R m n × c I ,第 s s s 块记 K s IComp ∈ R c I K_s^{\text{IComp}}\in\mathbb{R}^{c^I} K s IComp ∈ R c I 。
索引器 query (低秩;W D Q ∈ R d × d c W^{DQ}\in\mathbb{R}^{d\times d_c} W D Q ∈ R d × d c ,W I U Q ∈ R d c × c I n h I W^{IUQ}\in\mathbb{R}^{d_c\times c^I n_h^I} W I U Q ∈ R d c × c I n h I ):
c t Q = h t W D Q ∈ R d c \mathbf c_t^Q=h_tW^{DQ}\ \in\mathbb{R}^{d_c}
c t Q = h t W D Q ∈ R d c
q t I = c t Q W I U Q = [ q t , 1 I ; … ; q t , n h I I ] , q t , h I ∈ R c I q_t^I=\mathbf c_t^Q W^{IUQ}=[q^I_{t,1};\dots;q^I_{t,n_h^I}],\qquad q^I_{t,h}\in\mathbb{R}^{c^I}
q t I = c t Q W I U Q = [ q t , 1 I ; … ; q t , n h I I ] , q t , h I ∈ R c I
每头打分权重 (W w ∈ R d × n h I W^w\in\mathbb{R}^{d\times n_h^I} W w ∈ R d × n h I ):
w t I = h t W w = [ w t , 1 I ; … ; w t , n h I I ] , w t , h I ∈ R w_t^I=h_tW^w=[w^I_{t,1};\dots;w^I_{t,n_h^I}],\qquad w^I_{t,h}\in\mathbb{R}
w t I = h t W w = [ w t , 1 I ; … ; w t , n h I I ] , w t , h I ∈ R
索引分数 (h h h 遍历 n h I n_h^I n h I 个索引头):
I t , s = ∑ h = 1 n h I w t , h I R e L U ( q t , h I ⋅ K s IComp ) ∈ R I_{t,s}=\sum_{h=1}^{n_h^I}w^I_{t,h}\,\mathrm{ReLU}\!\big(q^I_{t,h}\cdot K_s^{\text{IComp}}\big)\ \in\mathbb{R}
I t , s = h = 1 ∑ n h I w t , h I R e L U ( q t , h I ⋅ K s IComp ) ∈ R
top-k k k 选块 (因果:只在前序完整块 s < ⌊ t / m ⌋ s<\lfloor t/m\rfloor s < ⌊ t / m ⌋ 上选最高的 k k k 个):
S t = T o p k ( { I t , s } s < ⌊ t / m ⌋ , k ) , C t SprsComp = { C s Comp : s ∈ S t } ∈ R k × c \mathcal S_t=\mathrm{Topk}\big(\{I_{t,s}\}_{\,s<\lfloor t/m\rfloor},\ k\big),\qquad C_t^{\text{SprsComp}}=\{C_s^{\text{Comp}}:s\in\mathcal S_t\}\in\mathbb{R}^{k\times c}
S t = T o p k ( { I t , s } s < ⌊ t / m ⌋ , k ) , C t SprsComp = { C s Comp : s ∈ S t } ∈ R k × c
CSA·第三步:共享 KV 的 MQA
注意力 query 从同一个 潜向量 c t Q \mathbf c_t^Q c t Q 升维(与索引器共享 c t Q \mathbf c_t^Q c t Q ;W U Q ∈ R d c × c n h W^{UQ}\in\mathbb{R}^{d_c\times c n_h} W U Q ∈ R d c × c n h ):
q t = c t Q W U Q = [ q t , 1 ; … ; q t , n h ] , q t , i ∈ R c q_t=\mathbf c_t^Q W^{UQ}=[q_{t,1};\dots;q_{t,n_h}],\qquad q_{t,i}\in\mathbb{R}^{c}
q t = c t Q W U Q = [ q t , 1 ; … ; q t , n h ] , q t , i ∈ R c
MQA:n h n_h n h 个 query 头共用同一份 选中条目(每条目既当 key 又当 value):
o t , i = C o r e A t t n ( q t , i , C t SprsComp , C t SprsComp ) ∈ R c , i = 1 , … , n h o_{t,i}=\mathrm{CoreAttn}\big(q_{t,i},\,C_t^{\text{SprsComp}},\,C_t^{\text{SprsComp}}\big)\in\mathbb{R}^{c},\qquad i=1,\dots,n_h
o t , i = C o r e A t t n ( q t , i , C t SprsComp , C t SprsComp ) ∈ R c , i = 1 , … , n h
C o r e A t t n ( q , K , V ) = S o f t m a x ( q K ⊤ c ) V ( 含因果掩码与下文 attention sink ) \mathrm{CoreAttn}(q,K,V)=\mathrm{Softmax}\!\Big(\tfrac{qK^\top}{\sqrt{c}}\Big)V\quad(\text{含因果掩码与下文 attention sink})
C o r e A t t n ( q , K , V ) = S o f t m a x ( c q K ⊤ ) V ( 含因果掩码与下文 attention sink )
分组输出投影(CSA / HCA 共用)
多头输出拼接 o t = [ o t , 1 ; … ; o t , n h ] ∈ R c n h o_t=[o_{t,1};\dots;o_{t,n_h}]\in\mathbb{R}^{c n_h} o t = [ o t , 1 ; … ; o t , n h ] ∈ R c n h 维度很大(Pro 为 512 × 128 = 65536 512\times128=65536 5 1 2 × 1 2 8 = 6 5 5 3 6 ),直接投回 d d d 维代价高。改两段式:n h n_h n h 头分 g g g 组,第 i i i 组拼为 o t , i G ∈ R c n h / g o_{t,i}^{G}\in\mathbb{R}^{c\,n_h/g} o t , i G ∈ R c n h / g ,先各自投到 d g d_g d g 维(d g < c n h / g d_g<c\,n_h/g d g < c n h / g ),再拼接投到 d d d 维(投影矩阵 W i G ∈ R c n h / g × d g W_i^{G}\in\mathbb{R}^{c n_h/g\times d_g} W i G ∈ R c n h / g × d g 、W O ∈ R d g g × d W^{O}\in\mathbb{R}^{d_g g\times d} W O ∈ R d g g × d ,均可学习):
o t , i G ′ = o t , i G W i G ∈ R d g ( i = 1 , … , g ) o_{t,i}^{G'}=o_{t,i}^{G}\,W_i^{G}\in\mathbb{R}^{d_g}\quad(i=1,\dots,g)
o t , i G ′ = o t , i G W i G ∈ R d g ( i = 1 , … , g )
o ^ t = [ o t , 1 G ′ ; … ; o t , g G ′ ] W O ∈ R d \hat o_t=\big[o_{t,1}^{G'};\dots;o_{t,g}^{G'}\big]\,W^{O}\in\mathbb{R}^{d}
o ^ t = [ o t , 1 G ′ ; … ; o t , g G ′ ] W O ∈ R d
其中 o t , i G ′ o_{t,i}^{G'} o t , i G ′ 即"中间输出",是两段式投影的内在一步(HCA 同)。
其他细节:RMSNorm、部分 RoPE、滑动窗口、Attention Sink
Q/KV RMSNorm。 核心注意力前,对每个 query 头、以及压缩 KV 条目(单头)各做一次 RMSNorm,避免注意力 logit 爆炸。
部分 RoPE。 只对 query / KV / 输出向量的末 r = 64 r=64 r = 6 4 维 施加 RoPE。因共享 KV(条目既是 key 又是 value),给它转 RoPE 当 key 时 value 也被转。记 a t , s a_{t,s} a t , s 为注意力权重、e s e_s e s 为第 s s s 个被选条目的未旋转内容、R p R_p R p 为位置 p p p 的 RoPE 旋转,则
o t = ∑ s a t , s R s e s o_t=\sum_{s}a_{t,s}\,R_s\,e_s
o t = s ∑ a t , s R s e s
按绝对位置 R s R_s R s 残留位置。对策:对输出再施位置 − t -t − t 的 RoPE,把绝对变相对:
R − t o t = ∑ s a t , s R s − t e s R_{-t}\,o_t=\sum_{s}a_{t,s}\,R_{s-t}\,e_s
R − t o t = s ∑ a t , s R s − t e s
每个 query-头只做一次(共 n h n_h n h 次),与命中多少 key 无关——R − t R_{-t} R − t 只依赖 t t t ,可由线性性提到求和外面整体施加一次。
滑动窗口分支。 压缩分支因因果只看完整前序块(s < ⌊ t / m ⌋ s<\lfloor t/m\rfloor s < ⌊ t / m ⌋ ),看不到当前块内 token,且压缩会模糊近处 token 的细节。故每个 query 额外、无条件地保留最近 n w i n = 128 n_{win}=128 n w i n = 1 2 8 个未压缩 KV 条目,与 C t SprsComp C_t^{\text{SprsComp}} C t SprsComp 拼接后一起做 MQA。它定长、随位置滑动,推理按状态缓存管理(开销为常数)。HCA 同样配此分支。
Attention Sink。 每个头加一个可学习 sink logit z h ′ z'_h z h ′ 进 softmax 分母,允许该头注意力权重总和不为 1(甚至接近 0):
a t , s ( h ) = exp ( ℓ t , s ( h ) ) ∑ s ′ exp ( ℓ t , s ′ ( h ) ) + exp ( z h ′ ) a_{t,s}^{(h)}=\frac{\exp(\ell_{t,s}^{(h)})}{\sum_{s'}\exp(\ell_{t,s'}^{(h)})+\exp(z'_h)}
a t , s ( h ) = ∑ s ′ exp ( ℓ t , s ′ ( h ) ) + exp ( z h ′ ) exp ( ℓ t , s ( h ) )
其中 ℓ t , s ( h ) \ell_{t,s}^{(h)} ℓ t , s ( h ) 为第 h h h 头 query t t t 对条目 s s s 的注意力 logit。
HCA:与 CSA 的差异
HCA 压缩率更高且不做稀疏选择 。压缩用单路、不重叠 (W K V , W Z ∈ R d × c W^{KV},W^Z\in\mathbb{R}^{d\times c} W K V , W Z ∈ R d × c ,位置偏置 B ∈ R m ′ × c B\in\mathbb{R}^{m'\times c} B ∈ R m ′ × c ):
C = H W K V , Z = H W Z ∈ R n × c C=HW^{KV},\qquad Z=HW^Z\ \in\mathbb{R}^{n\times c}
C = H W K V , Z = H W Z ∈ R n × c
S m ′ i : m ′ ( i + 1 ) − 1 = S o f t m a x row ( Z m ′ i : m ′ ( i + 1 ) − 1 + B ) S_{m'i:m'(i+1)-1}=\mathrm{Softmax}_{\text{row}}\!\big(Z_{m'i:m'(i+1)-1}+B\big)
S m ′ i : m ′ ( i + 1 ) − 1 = S o f t m a x row ( Z m ′ i : m ′ ( i + 1 ) − 1 + B )
C i Comp = ∑ j = m ′ i m ′ ( i + 1 ) − 1 S j ⊙ C j ∈ R c , C Comp ∈ R n m ′ × c C_i^{\text{Comp}}=\sum_{j=m'i}^{m'(i+1)-1}S_j\odot C_j\ \in\mathbb{R}^c,\qquad C^{\text{Comp}}\in\mathbb{R}^{\frac{n}{m'}\times c}
C i Comp = j = m ′ i ∑ m ′ ( i + 1 ) − 1 S j ⊙ C j ∈ R c , C Comp ∈ R m ′ n × c
query 同样低秩生成(c t Q = h t W D Q \mathbf c_t^Q=h_tW^{DQ} c t Q = h t W D Q 、q t = c t Q W U Q q_t=\mathbf c_t^Q W^{UQ} q t = c t Q W U Q ),但核心注意力对全部压缩条目稠密 做(无索引器、无 top-k k k ):
o t , i = C o r e A t t n ( q t , i , C Comp , C Comp ) ∈ R c o_{t,i}=\mathrm{CoreAttn}\big(q_{t,i},\,C^{\text{Comp}},\,C^{\text{Comp}}\big)\in\mathbb{R}^c
o t , i = C o r e A t t n ( q t , i , C Comp , C Comp ) ∈ R c
其余(共享 KV MQA、分组输出投影、滑动窗口、RMSNorm、部分 RoPE)与 CSA 一致。差异归纳:
CSA
HCA
压缩率
m = 4 m=4 m = 4
m ′ = 128 m'=128 m ′ = 1 2 8
压缩方式
两路 C a , C b C^a,C^b C a , C b + 重叠
单路 C C C + 不重叠
稀疏选择
Lightning Indexer + top-k k k
无(稠密)
逻辑链:m ′ m' m ′ 极大 ⟹ 压缩后条目数 n / m ′ n/m' n / m ′ 已很少 ⟹ 稠密注意力的计算量也可接受 ⟹ 不需索引器与 top-k k k 。
低秩 query(down → up)为何如此
CSA/HCA 的 query 都先 c t Q = h t W D Q \mathbf c_t^Q=h_tW^{DQ} c t Q = h t W D Q 降到 d c d_c d c 、再 q t = c t Q W U Q q_t=\mathbf c_t^Q W^{UQ} q t = c t Q W U Q 升回 c n h c n_h c n h 维,等价于把大投影 W ∈ R d × c n h W\in\mathbb{R}^{d\times c n_h} W ∈ R d × c n h 做低秩分解。动机:
省参数 :V4-Pro(d = 7168 , c n h = 65536 , d c = 1536 d{=}7168,\ c n_h{=}65536,\ d_c{=}1536 d = 7 1 6 8 , c n h = 6 5 5 3 6 , d c = 1 5 3 6 )直接投影约 470 470 4 7 0 M,低秩 7168 × 1536 + 1536 × 65536 ≈ 112 7168\times1536+1536\times65536\approx112 7 1 6 8 × 1 5 3 6 + 1 5 3 6 × 6 5 5 3 6 ≈ 1 1 2 M,约省 4 倍;
共享潜向量 :c t Q \mathbf c_t^Q c t Q 同时供索引器 query 与注意力 query(降维只算一次),使"选块意图"与"用块意图"同源;
省训练激活 :只缓存小的 c t Q \mathbf c_t^Q c t Q 。
升维不可省:注意力 query 必须与 c c c 维键同维、且要 n h n_h n h 个头。query 不进 KV cache,故此低秩与省推理缓存无关。
不再使用 MLA
V4 用 CSA/HCA 取代了 V2/V3 的 MLA。两者压缩的轴不同:MLA 沿特征维 把每个 token 的 KV 压成低秩潜向量(token 数不变);CSA/HCA 沿序列维 把多个 token 合并成一个条目(token 数减少),再以稀疏只读一部分。后者针对的正是百万上下文里"KV 条目数量随长度线性增长"这一 MLA 无法缓解的瓶颈。V4 仍继承了 MLA 的若干组件:低秩 query、解耦/部分 RoPE、MQA 共享 KV,以及"极致压缩 KV cache"的整体取向。
2.4 Muon 优化器
V4 把主优化器从 V3 的 AdamW 换成 Muon——这是它三项架构创新之一。Muon 对二维权重矩阵把更新矩阵正交化(保留方向、把各方向步长拉齐),等价于在谱范数意义下做最速下降;嵌入、预测头、mHC 静态偏置与门控、所有 RMSNorm 仍用 AdamW。
Muon 解决什么问题:正交化的两种写法
AdamW 把权重矩阵视为互不相关的标量逐元素缩放,丢掉了"这是一个线性映射"的结构。而梯度(及其动量)的能量往往高度集中在少数方向上——少数奇异值很大、其余很小,于是普通梯度下降几乎只沿那几个主方向更新,大量幅度小但同样重要的方向几乎不更新。Muon 的做法是把更新矩阵正交化 :方向全保留,但把各方向步长拉成一样大。
对更新矩阵 M M M (在 V4 里 M = μ M t + G t M=\mu M_t+G_t M = μ M t + G t ,即下文 Hybrid NS 的输入),其奇异值分解写成两种等价形式:
M = U Σ V ⊤ = ∑ i = 1 r σ i u i v i ⊤ , M=U\Sigma V^\top=\sum_{i=1}^{r}\sigma_i\,u_i v_i^\top ,
M = U Σ V ⊤ = i = 1 ∑ r σ i u i v i ⊤ ,
其中 r = r a n k ( M ) r=\mathrm{rank}(M) r = r a n k ( M ) ,u i , v i u_i,v_i u i , v i 为左/右奇异向量(U , V U,V U , V 的第 i i i 列),σ i \sigma_i σ i 为奇异值。左式把"旋转-拉伸-旋转"打包成矩阵;右式把 M M M 摊成 r r r 个秩-1 方向 u i v i ⊤ u_i v_i^\top u i v i ⊤ 的加权和、权重正是 σ i \sigma_i σ i 。正交化就是把所有 σ i \sigma_i σ i 设为 1:
o r t h ( M ) = U V ⊤ = ∑ i = 1 r u i v i ⊤ , \mathrm{orth}(M)=UV^\top=\sum_{i=1}^{r}u_i v_i^\top ,
o r t h ( M ) = U V ⊤ = i = 1 ∑ r u i v i ⊤ ,
方向 u i , v i u_i,v_i u i , v i 不动、强度一律拉平到 1。可证这恰是谱范数意义下的最速下降方向(完整论证见专篇 Muon 优化器 )。
Hybrid Newton–Schulz:计算推导
直接对每个权重做 SVD 太慢、且 GPU 低精度下不稳。Muon 改用只含矩阵乘的迭代近似 U V ⊤ UV^\top U V ⊤ 。先归一化 M 0 = M / ∥ M ∥ F M_0=M/\|M\|_F M 0 = M / ∥ M ∥ F (保证最大奇异值 ≤ 1 \le1 ≤ 1 ),再迭代:
M k = a M k − 1 + b ( M k − 1 M k − 1 ⊤ ) M k − 1 + c ( M k − 1 M k − 1 ⊤ ) 2 M k − 1 . M_k=a M_{k-1}+b\,(M_{k-1}M_{k-1}^\top)M_{k-1}+c\,(M_{k-1}M_{k-1}^\top)^2 M_{k-1}.
M k = a M k − 1 + b ( M k − 1 M k − 1 ⊤ ) M k − 1 + c ( M k − 1 M k − 1 ⊤ ) 2 M k − 1 .
为什么它等价于"对奇异值套多项式"。 设 M k − 1 = U Σ V ⊤ M_{k-1}=U\Sigma V^\top M k − 1 = U Σ V ⊤ ,利用 U ⊤ U = V ⊤ V = I U^\top U=V^\top V=I U ⊤ U = V ⊤ V = I :
M k − 1 M k − 1 ⊤ = U Σ 2 U ⊤ , ( M k − 1 M k − 1 ⊤ ) M k − 1 = U Σ 3 V ⊤ , ( M k − 1 M k − 1 ⊤ ) 2 M k − 1 = U Σ 5 V ⊤ , M_{k-1}M_{k-1}^\top=U\Sigma^2U^\top,\quad (M_{k-1}M_{k-1}^\top)M_{k-1}=U\Sigma^3V^\top,\quad (M_{k-1}M_{k-1}^\top)^2M_{k-1}=U\Sigma^5V^\top,
M k − 1 M k − 1 ⊤ = U Σ 2 U ⊤ , ( M k − 1 M k − 1 ⊤ ) M k − 1 = U Σ 3 V ⊤ , ( M k − 1 M k − 1 ⊤ ) 2 M k − 1 = U Σ 5 V ⊤ ,
⇒ M k = U ( a Σ + b Σ 3 + c Σ 5 ) V ⊤ . \Rightarrow\quad M_k=U\big(a\Sigma+b\Sigma^3+c\Sigma^5\big)V^\top .
⇒ M k = U ( a Σ + b Σ 3 + c Σ 5 ) V ⊤ .
U , V U,V U , V 始终不变,迭代只把每个奇异值按 p ( σ ) = a σ + b σ 3 + c σ 5 p(\sigma)=a\sigma+b\sigma^3+c\sigma^5 p ( σ ) = a σ + b σ 3 + c σ 5 更新。只要 p p p 能把 ( 0 , 1 ] (0,1] ( 0 , 1 ] 内的 σ \sigma σ 推向 1,迭代就把 M M M 推向 U V ⊤ UV^\top U V ⊤ 。(p p p 只含奇数次幂,是因为每乘一次 M M ⊤ M M^\top M M ⊤ 就给 Σ \Sigma Σ 加两次幂。)
两段式系数(共 10 步)。
前 8 步用 ( a , b , c ) = ( 3.4445 , − 4.7750 , 2.0315 ) (a,b,c)=(3.4445,-4.7750,2.0315) ( a , b , c ) = ( 3 . 4 4 4 5 , − 4 . 7 7 5 0 , 2 . 0 3 1 5 ) :近零增益 p ′ ( 0 ) = a ≈ 3.44 p'(0)=a\approx3.44 p ′ ( 0 ) = a ≈ 3 . 4 4 很大,把很小的奇异值快速抬起,几步内让所有 σ \sigma σ 聚到 1 附近。但 p ( 1 ) = 3.4445 − 4.7750 + 2.0315 ≈ 0.70 ≠ 1 p(1)=3.4445-4.7750+2.0315\approx0.70\neq1 p ( 1 ) = 3 . 4 4 4 5 − 4 . 7 7 5 0 + 2 . 0 3 1 5 ≈ 0 . 7 0 = 1 ,它并不固定 1,奇异值只是落进 1 附近的窄带、仍会摆动;
后 2 步切到 ( a , b , c ) = ( 2 , − 1.5 , 0.5 ) (a,b,c)=(2,-1.5,0.5) ( a , b , c ) = ( 2 , − 1 . 5 , 0 . 5 ) :p ( 1 ) = 2 − 1.5 + 0.5 = 1 p(1)=2-1.5+0.5=1 p ( 1 ) = 2 − 1 . 5 + 0 . 5 = 1 、p ′ ( 1 ) = a + 3 b + 5 c = 2 − 4.5 + 2.5 = 0 p'(1)=a+3b+5c=2-4.5+2.5=0 p ′ ( 1 ) = a + 3 b + 5 c = 2 − 4 . 5 + 2 . 5 = 0 ,在 σ = 1 \sigma=1 σ = 1 处取不动点且一阶导为零,二次收敛 ,把窄带内的奇异值精确锁到 1。
"先用激进系数快速逼近、再用锁定系数精确收敛"即 hybrid 之意:只用激进系数得不到精确正交,只用锁定系数对小奇异值收敛太慢。
Nesterov trick 与 RMS 重缩放
完整算法(报告 Algorithm 1;W ∈ R n × m W\in\mathbb R^{n\times m} W ∈ R n × m ,学习率 η \eta η 、动量 μ \mu μ 、权重衰减 λ \lambda λ 、缩放因子 γ \gamma γ ):
G t = ∇ W L t ( W t − 1 ) ▹ 梯度 M t = μ M t − 1 + G t ▹ 累积动量 O t ′ = H y b r i d N e w t o n S c h u l z ( μ M t + G t ) ▹ Nesterov trick + 正交化 O t = O t ′ ⋅ max ( n , m ) ⋅ γ ▹ RMS 重缩放 W t = W t − 1 ( 1 − η λ ) − η O t ▹ 权重衰减 + 更新 \begin{aligned}
G_t&=\nabla_W\mathcal L_t(W_{t-1}) &&\triangleright\ \text{梯度}\\
M_t&=\mu M_{t-1}+G_t &&\triangleright\ \text{累积动量}\\
O'_t&=\mathrm{HybridNewtonSchulz}(\mu M_t+G_t) &&\triangleright\ \text{Nesterov trick + 正交化}\\
O_t&=O'_t\cdot\sqrt{\max(n,m)}\cdot\gamma &&\triangleright\ \text{RMS 重缩放}\\
W_t&=W_{t-1}(1-\eta\lambda)-\eta O_t &&\triangleright\ \text{权重衰减 + 更新}
\end{aligned}
G t M t O t ′ O t W t = ∇ W L t ( W t − 1 ) = μ M t − 1 + G t = H y b r i d N e w t o n S c h u l z ( μ M t + G t ) = O t ′ ⋅ max ( n , m ) ⋅ γ = W t − 1 ( 1 − η λ ) − η O t ▹ 梯度 ▹ 累积动量 ▹ Nesterov trick + 正交化 ▹ RMS 重缩放 ▹ 权重衰减 + 更新
(a) Nesterov trick(第 3 行)。 正交化的对象不是普通动量 M t M_t M t ,而是前瞻方向 μ M t + G t \mu M_t+G_t μ M t + G t 。普通(重球)动量直接沿累积方向 M t M_t M t 更新,问题是梯度 G t G_t G t 在当前点 W t − 1 W_{t-1} W t − 1 测得,可动量大时这一步会移动到更靠前的位置,当前斜率到那里已不准确,容易过冲、震荡。Nesterov 的修正是"先顺动量探到前瞻点、在那里再测梯度",等价整理进缓冲后更新方向变为 μ M t + G t \mu M_t+G_t μ M t + G t 。展开
μ M t + G t = ( 1 + μ ) G t + μ 2 M t − 1 , \mu M_t+G_t=(1+\mu)G_t+\mu^2 M_{t-1},
μ M t + G t = ( 1 + μ ) G t + μ 2 M t − 1 ,
相比重球的 M t = G t + μ M t − 1 M_t=G_t+\mu M_{t-1} M t = G t + μ M t − 1 ,它给最新梯度更高权重 ( 1 + μ ) (1+\mu) ( 1 + μ ) 、把历史动量系数从 μ \mu μ 收紧到 μ 2 \mu^2 μ 2 。作用:若已越过极小点,前瞻点的梯度更早反向,提前抑制过冲——震荡更小、收敛更快更稳。
(b) RMS 重缩放(第 4 行)及如何消除形状依赖。 把更新的典型幅度用每元素均方根度量:R M S ( A ) = ∥ A ∥ F / n m \mathrm{RMS}(A)=\|A\|_F/\sqrt{nm} R M S ( A ) = ∥ A ∥ F / n m (∥ ⋅ ∥ F \|\cdot\|_F ∥ ⋅ ∥ F 为 Frobenius 范数)。正交化输出 O ′ = U V ⊤ O'=UV^\top O ′ = U V ⊤ 有 min ( n , m ) \min(n,m) min ( n , m ) 个奇异值且全为 1,故 ∥ U V ⊤ ∥ F 2 = ∑ i = 1 min ( n , m ) 1 = min ( n , m ) \|UV^\top\|_F^2=\sum_{i=1}^{\min(n,m)}1=\min(n,m) ∥ U V ⊤ ∥ F 2 = ∑ i = 1 m i n ( n , m ) 1 = min ( n , m ) ,于是
R M S ( O ′ ) = min ( n , m ) n m = 1 max ( n , m ) . \mathrm{RMS}(O')=\frac{\sqrt{\min(n,m)}}{\sqrt{nm}}=\frac{1}{\sqrt{\max(n,m)}} .
R M S ( O ′ ) = n m min ( n , m ) = max ( n , m ) 1 .
可见正交化更新固有的 RMS 是 1 / max ( n , m ) 1/\sqrt{\max(n,m)} 1 / max ( n , m ) ——只取决于矩阵两维里较大的那个,矩阵越大、RMS 越小 ,这就是形状依赖。乘上 max ( n , m ) ⋅ γ \sqrt{\max(n,m)}\cdot\gamma max ( n , m ) ⋅ γ 后,max ( n , m ) \sqrt{\max(n,m)} max ( n , m ) 恰好把这个因子抵消:
R M S ( O t ) = 1 max ( n , m ) ⋅ max ( n , m ) ⋅ γ = γ , \mathrm{RMS}(O_t)=\frac{1}{\sqrt{\max(n,m)}}\cdot\sqrt{\max(n,m)}\cdot\gamma=\gamma ,
R M S ( O t ) = max ( n , m ) 1 ⋅ max ( n , m ) ⋅ γ = γ ,
变成与矩阵尺寸无关的常数 γ \gamma γ (V4 取使 RMS ≈ 0.18 \approx0.18 ≈ 0 . 1 8 的值)。这样无论权重多大,Muon 更新幅度都对齐到同一尺度、并与 AdamW 一致,于是能直接复用 AdamW 调好的学习率与调度(第 5 行权重衰减亦采用 AdamW 式解耦形式 W ( 1 − η λ ) W(1-\eta\lambda) W ( 1 − η λ ) )。
为什么不用 QK-Clip
问题。 Muon 下注意力的 Query/Key 投影权重在训练中持续增大,使 softmax 之前的注意力 logit 失控、引发损失尖峰。记头 h h h 上 token i , j i,j i , j 的 logit、以及该头在一个 batch B \mathcal B B 上的最大 logit 为(d h d_h d h 为注意力头维):
S i j h = 1 d h q i h ⋅ k j h , q i h = x i W q h , k j h = x j W k h , S max h = max x ∈ B max i , j S i j h . S^h_{ij}=\frac{1}{\sqrt{d_h}}\,q^h_i\cdot k^h_j,\qquad q^h_i=x_iW_q^h,\quad k^h_j=x_jW_k^h,\qquad S^h_{\max}=\max_{x\in\mathcal B}\ \max_{i,j}S^h_{ij}.
S i j h = d h 1 q i h ⋅ k j h , q i h = x i W q h , k j h = x j W k h , S m a x h = x ∈ B max i , j max S i j h .
QK-Clip(Kimi K2 的对策)。 在每步 Muon 更新之后 ,逐头取该步前向传播中已记录的 最大 logit S max h S^h_{\max} S m a x h (前向算注意力时顺带取 max,几乎无额外开销);若超过阈值 τ \tau τ (K2 取 τ = 100 \tau=100 τ = 1 0 0 ),按比例缩小该头的 Q/K 投影权重,把最大 logit 钳回 τ \tau τ 。缩放因子与权重更新为
γ h = min ( 1 , τ S max h ) , W q h ← γ h α W q h , W k h ← γ h 1 − α W k h ( α = 0.5 ) . \gamma_h=\min\!\Big(1,\ \frac{\tau}{S^h_{\max}}\Big),\qquad
W_q^h\leftarrow\gamma_h^{\alpha}\,W_q^h,\quad
W_k^h\leftarrow\gamma_h^{1-\alpha}\,W_k^h\quad(\alpha=0.5).
γ h = min ( 1 , S m a x h τ ) , W q h ← γ h α W q h , W k h ← γ h 1 − α W k h ( α = 0 . 5 ) .
α = 0.5 \alpha=0.5 α = 0 . 5 时 W q h , W k h W_q^h,W_k^h W q h , W k h 各乘 γ h \sqrt{\gamma_h} γ h ,于是 q h ⋅ k h q^h\!\cdot k^h q h ⋅ k h 整体乘 γ h \gamma_h γ h 、最大 logit 变为 γ h S max h = τ \gamma_h S^h_{\max}=\tau γ h S m a x h = τ 。它不改变当前步的前向/反向,只是训练循环外对权重的事后修正;未超阈值的头 γ h = 1 \gamma_h=1 γ h = 1 、保持不变。
V4 的做法:Q/K RMSNorm。 V4 在核心注意力前,对每个 query 头与压缩 KV 条目各做一次 RMSNorm(见 §2.3「其他细节」)。对向量 x ∈ R c x\in\mathbb R^{c} x ∈ R c (c c c 为头维)、可学习增益 g ∈ R c g\in\mathbb R^{c} g ∈ R c :
R M S N o r m ( x ) = x 1 c ∑ j = 1 c x j 2 + ϵ ⊙ g , q ^ = R M S N o r m ( q ) , k ^ = R M S N o r m ( k ) , \mathrm{RMSNorm}(x)=\frac{x}{\sqrt{\tfrac1c\sum_{j=1}^{c}x_j^2+\epsilon}}\odot g,\qquad
\hat q=\mathrm{RMSNorm}(q),\quad \hat k=\mathrm{RMSNorm}(k),
R M S N o r m ( x ) = c 1 ∑ j = 1 c x j 2 + ϵ x ⊙ g , q ^ = R M S N o r m ( q ) , k ^ = R M S N o r m ( k ) ,
logit 改用 q ^ ⋅ k ^ / c \hat q\cdot\hat k/\sqrt{c} q ^ ⋅ k ^ / c 。归一化把 ∥ q ^ ∥ , ∥ k ^ ∥ \lVert\hat q\rVert,\lVert\hat k\rVert ∥ q ^ ∥ , ∥ k ^ ∥ 固定到由增益 g g g 决定的尺度,由 Cauchy–Schwarz ∣ q ^ ⋅ k ^ ∣ ≤ ∥ q ^ ∥ ∥ k ^ ∥ |\hat q\cdot\hat k|\le\lVert\hat q\rVert\,\lVert\hat k\rVert ∣ q ^ ⋅ k ^ ∣ ≤ ∥ q ^ ∥ ∥ k ^ ∥ ,logit 被天然限幅——无论投影权重训练中如何增大,归一化都在算分数前把尺度拉回,故 logit 不会失控。
两者对比。 都为限制注意力 logit 的尺度,但作用层面相反:
QK-Clip(Kimi K2)
Q/K RMSNorm(V4)
作用对象
权重 W q h , W k h W_q^h,W_k^h W q h , W k h
激活向量 q , k q,k q , k
时机
优化器更新后、训练循环外的额外步骤
前向传播内,每次注意力都做
触发
仅当 S max h > τ S^h_{\max}>\tau S m a x h > τ 才缩放(reactive)
始终归一化(always-on)
需要的量
逐头统计 S max h S^h_{\max} S m a x h 、阈值 τ \tau τ
可学习增益 g g g ;无需统计 logit
限幅依据
把 S max h S^h_{\max} S m a x h 钳到 ≤ τ \le\tau ≤ τ
∥ q ^ ∥ , ∥ k ^ ∥ \lVert\hat q\rVert,\lVert\hat k\rVert ∥ q ^ ∥ , ∥ k ^ ∥ 固定 ⇒ \Rightarrow ⇒ Cauchy–Schwarz
对已学注意力的影响
直接缩小权重,会改动已学到的注意力分布
不裁剪权重;归一化是架构的一部分
推理时
训练手段,推理时权重已固定
架构组件,推理时仍执行
V4 因此把抑制 logit 增长的机制放进架构 (前置归一化),而非放进优化器 (事后裁剪权重):前者不改变学到的注意力分布、也不必逐头跟踪 S max h S^h_{\max} S m a x h 与缩放系数。故 V4 的 Muon 不含 QK-Clip。
优化器分配:哪些用 Muon、哪些用 AdamW
划分规则(报告 §2.4 Basic Configurations、§4.2.2):AdamW 只负责 embedding、prediction head、mHC 的 static biases 与 gating factors、所有 RMSNorm 权重;其余参数(绝大多数,即所有二维权重矩阵)全部用 Muon 。这就是标准 Muon 配方——只对二维隐藏矩阵正交化,embedding/输出头/norm/一维参数留给 AdamW。Pro 与 Flash 的分配一致,仅数值超参不同。
模块
优化器
说明
Attention projections(CSA / HCA) :KV 压缩 W a K V , W b K V , W K V W^{aKV},W^{bKV},W^{KV} W a K V , W b K V , W K V 、query 降/升维 W D Q , W U Q W^{DQ},W^{UQ} W D Q , W U Q 、Lightning Indexer W I U Q , W w W^{IUQ},W^{w} W I U Q , W w 、分组输出 W G , W O W^{G},W^{O} W G , W O
Muon
二维隐藏层权重矩阵
FFN / MoE experts :shared + routed expert 权重(所有 Transformer block 均为 MoE)
Muon
二维隐藏层权重矩阵
MTP module :自带 Transformer block(TRM)+ 投影矩阵 M M M
Muon
二维隐藏层权重矩阵
其余所有二维线性权重(含 mHC dynamic mixing 矩阵)
Muon
报告原文 “all other modules” 全归 Muon
Embedding module
AdamW
词嵌入查找表,非隐藏线性映射
Prediction head (输出 / LM head)
AdamW
映射到词表的输出层
mHC 的 static biases 与 gating factors
AdamW
报告 §2.4 显式划归 AdamW
所有 RMSNorm 权重(gain)
AdamW
一维参数,Muon 只作用于矩阵
超参(两模型一致):Muon — momentum μ = 0.95 \mu=0.95 μ = 0 . 9 5 、weight decay 0.1 0.1 0 . 1 、更新矩阵 RMS 重缩放到 0.18 0.18 0 . 1 8 (以复用 AdamW 学习率);AdamW — β 1 = 0.9 , β 2 = 0.95 , ε = 1 0 − 20 \beta_1=0.9,\ \beta_2=0.95,\ \varepsilon=10^{-20} β 1 = 0 . 9 , β 2 = 0 . 9 5 , ε = 1 0 − 2 0 、weight decay 0.1 0.1 0 . 1 。另注:MoE 路由的负载均衡偏置 b i b_i b i 由 §2.1 的启发式规则更新(无梯度),既不归 Muon 也不归 AdamW。
Muon 的完整原理(正交化为何有效、谱范数最速下降、与 Shampoo 的关系、规模化与采用情况)见本博客专篇 Muon 优化器 。
三、General Infrastructures 🚧 TODO
对应报告 §3。涵盖:细粒度专家并行的通信-计算重叠(融合 MegaMoE kernel)、TileLang DSL(Host Codegen、Z3 形式化整数分析)、batch-invariant 与确定性 kernel(逐比特可复现)、训练框架(面向 Muon 的混合 ZeRO、mHC 重计算、两阶段上下文并行、张量级激活检查点)、推理框架(异构 KV cache 布局、磁盘 KV 存储复用共享前缀)。
TODO:后续补充。
四、Pre-Training 🚧 TODO
对应报告 §4。涵盖:数据构建(强调长文档、32T+ token)、预训练设置(4K→16K→64K→1M 序列长度课程、稀疏注意力热身)、训练稳定性手段(Anticipatory Routing 前瞻路由、SwiGLU Clamping)、基座模型评测(V4-Flash-Base 反超 V3.2-Base)。
TODO:后续补充。
五、Post-Training 🚧 TODO
对应报告 §5。涵盖:领域专家训练(SFT + GRPO)+ 在线策略蒸馏(OPD,取代混合 RL)、三档推理力度(Non-think / Think High / Think Max)、生成式奖励模型(GRM)、FP4 量化感知训练、可抢占容错 rollout、智能体沙箱 DSec、标准与真实任务评测。
TODO:后续补充。
附注