AI

vllm 中的 GDN

Posted by w@hidva.com on November 23, 2025

随着 Qwen3-Next PD 分离机制的陆续上线, 借此机会系统梳理一下 Gated Delta Rule 的计算流程以及 vllm 中的实现.

recurrent

如 GDN 论文 “Gated Delta Networks: Improving Mamba2 with Delta Rule” 所示, Gated Delta Rule 有 recurrent 与 chunk 两种形式, 前者适用与推理的 decode 阶段, 后者则适用于推理的 prefill 以及训练阶段. 简单起见, 我们先看下 recurrent 实现, 对应的计算公式:

\[\begin{align} \mathrm{S}_t &= \alpha_t \mathrm{S}_{t-1} (\mathrm{I} - \beta_t k_t k_t^{\mathrm{T}}) + \beta_t v_t k_t^{\mathrm{T}} \\ o_t &= \mathrm{S}_t q_t \end{align}\]

入口在 fused_recurrent_gated_delta_rule.

def fused_recurrent_gated_delta_rule(
    q: torch.Tensor,  # [B, T, H, K]
    k: torch.Tensor,  # [B, T, H, K]
    v: torch.Tensor,  # [B, T, HV, V]
    g: torch.Tensor,  # [B, T, HV]
    beta: torch.Tensor = None,  # [B, T, HV]
    scale: float = None,
    initial_state: torch.Tensor = None,  # kvcache [num_blocks, HV, K, V]
    inplace_final_state: bool = True,
    cu_seqlens: torch.LongTensor | None = None,  # [N + 1]
    ssm_state_indices: torch.Tensor | None = None,  #  [N, num_sepc + 1]
    num_accepted_tokens: torch.Tensor | None = None, # [N]
    use_qk_l2norm_in_kernel: bool = True,
)

参数说明 (在 vLLM 连续批处理下,B 始终为 1):

  • T 本次 step 参与计算的 token 数.
  • H 对应着 num_k_heads / tp_size, HV 对应着 num_v_heads / tp_size
  • K 对应着 head_k_dim. V 对应着 head_v_dim.

ssm_state_indices[R]: list[int] 存放着请求 R 在 GDN layer 下使用的 kvcache block id. num_accepted_tokens 以 gamma=3 举例说明这个参数语义, 在 Step S1 中, 请求 R 输入 token 分别是 sample_token, draft_token1, draft_token2, draft_token3; draft_token1, draft_token2, draft_token3 为上一个 Step S0 生成的 draft token, 在 Step S1 进行验证. 在当前层 L 下 R 会占用 1 + gamma 个 kvcache block:

  1. Step S1, target model forward, 计算 sample_token, draft_token1, draft_token2, draft_token3 的 ssm state, 之后分别存放在指定的 kvcache block 中.
  2. Step S1, verify, 确认接受 draft_token1, draft_token2; 此时会 GPUModelRunner 侧记录 R 的状态: num_accepted_tokens[R] = 1 + 2.
  3. Step S2, target model forward, 对于 R 来说, 将从 initial_state[ssm_state_indices[R][num_accepted_tokens[R] - 1]] 中读取 init state, 即 init state 为 draft_token2 对应的 ssm_state.

对于一个特定请求 R 其在一个特定 vhead 下 ssm_state shape=(K, V); 一个 fused_recurrent_gated_delta_rule_fwd_kernel triton program 只负责计算 (BK=K, BV) 那一小部分.

  BV
+-----+-----+-----+-----+   -
|     |     |     |     |   |
|     |     |     |     |   |
|     |     |     |     |   BK = K
|     |     |     |     |   |
|     |     |     |     |   |
+-----+-----+-----+-----+   -

fused_recurrent_gated_delta_rule_fwd_kernel 中局部变量 i_hv 记录着当前 triton program 负责处理的 vhead. 在 Qwen3-Next 中, gdn 层 k head num=16, v head num = 32; 即 vhead v0, v1 共用着 khead k0; vhead v2, v3 共用着 khead k1; 以此类推. i_h 记录着当前 vhead 对应的 khead.

g, 这里 g 语义上对应着 GDN 论文中的 $\alpha$, 但其值并不是 $\alpha$, 而是 $\log(\alpha)$! 即假设某个特定请求 R 在某一 vhead 下 $\alpha = [\alpha_1, \alpha_2, \cdots, \alpha_T]$, 那么 $g = [\log(\alpha_1), \log(\alpha_2), \cdots, \log(\alpha_T)]$. 所以在计算 $\alpha_t \mathrm{S}_{t-1}$ 使用的是 b_h *= exp(b_g) 而不是 b_h *= b_g!

理解了这些, 再看 fused_recurrent_gated_delta_rule_fwd_kernel 实现就很容易理解了, 其就是朴素地实现了 Gated Delta Rule recurrent 计算公式. 这里就不详细说明了.

chunk

从 GDN 论文中可知, B.T.W GDN 论文中有一些 typo. 如下 Gated Delta Rule chunk 对应的计算公式根据具体实现总结:

\[\begin{align} \mathrm{A}_{[t]} &= \left[ \mathrm{I} + \mathrm{diag}({\beta_{[t]}}) (\mathrm{K}_{[t]} \mathrm{K}_{[t]}^{\mathrm{T}} \odot \Gamma_{[t]} \odot \mathrm{M}) \right]^{-1} \\ \mathrm{T}_{[t]} &= \mathrm{A}_{[t]} \mathrm{diag}({\beta_{[t]}}) \in \mathrm{R}^{C \times C} \\ \overleftarrow{\mathrm{W}_{[t]}} &= \mathrm{diag}(\gamma_{[t]}) \mathrm{W}_{[t]} = \mathrm{diag}(\gamma_{[t]}) \mathrm{T}_{[t]} \mathrm{K}_{[t]} \\ \mathrm{diag}(\gamma_{[t]}) &= \begin{bmatrix} \gamma_{[t]}^{1} & & & \\ & \gamma_{[t]}^{2} & & \\ & & \ddots & \\ & & & \gamma_{[t]}^{C} \end{bmatrix} \\ \widetilde{\mathrm{U}_{[t]}} &= \mathrm{T}_{[t]} \mathrm{V}_{[t]} \\ \mathrm{Vnew}_{[t]} &= \widetilde{\mathrm{U}_{[t]}} - \overleftarrow{\mathrm{W}_{[t]}} \mathrm{S}^{\mathrm{T}}_{[t]} \\ \mathrm{S}_{[t+1]} &= \overrightarrow{\mathrm{S}_{[t]}} + \mathrm{Vnew}_{[t]}^{\mathrm{T}} \overrightarrow{\mathrm{K}_{[t]}} \\ \overrightarrow{\mathrm{K}_{[t]}} &= \begin{pmatrix} \overrightarrow{k_{[t]}^1}^{\mathrm{T}} \\ \overrightarrow{k_{[t]}^2}^{\mathrm{T}} \\ \vdots \\ \overrightarrow{k_{[t]}^C}^{\mathrm{T}} \end{pmatrix} = \begin{pmatrix} \frac{\gamma_{[t]}^C}{\gamma_{[t]}^1} {k_{[t]}^1}^{\mathrm{T}} \\ \frac{\gamma_{[t]}^C}{\gamma_{[t]}^2} {k_{[t]}^2}^{\mathrm{T}} \\ \vdots \\ \frac{\gamma_{[t]}^C}{\gamma_{[t]}^C} {k_{[t]}^C}^{\mathrm{T}} \end{pmatrix} \in \mathrm{R}^{C \times d_k}, k_{[t]}^1 \in \mathrm{R}^{d_k \times 1} \\ \mathrm{O}_{[t]} &= \overleftarrow{\mathrm{Q}_{[t]}} \mathrm{S}^{\mathrm{T}}_{[t]} + \left( \mathrm{Q}_{[t]} \mathrm{K}^{\mathrm{T}}_{[t]} \odot \Gamma_{[t]} \odot \mathrm{M} \right) \mathrm{Vnew}_{[t]} \\ \overleftarrow{\mathrm{Q}_{[t]}} &= \begin{pmatrix} \overleftarrow{q_{[t]}^1}^{\mathrm{T}} \\ \overleftarrow{q_{[t]}^2}^{\mathrm{T}} \\ \vdots \\ \overleftarrow{q_{[t]}^C}^{\mathrm{T}} \end{pmatrix} = \begin{pmatrix} \gamma_{[t]}^1 {q_{[t]}^1}^{\mathrm{T}} \\ \gamma_{[t]}^2 {q_{[t]}^2}^{\mathrm{T}} \\ \vdots \\ \gamma_{[t]}^C {q_{[t]}^C}^{\mathrm{T}} \end{pmatrix} \in \mathrm{R}^{C \times d_k}, q_{[t]}^1 \in \mathrm{R}^{d_k \times 1} \\ \end{align}\]

chunk 实现对应入口是 chunk_gated_delta_rule, 整体执行过程是:

  1. 将请求输入 token 切分为 chunk 之后, 并行 对所有 chunk 计算对应的 $\overleftarrow{\mathrm{W}_{[t]}}$, $\widetilde{\mathrm{U}_{[t]}}$
  2. 串行 地处理每个 chunk, 计算每一个 chunk 对应的 $\mathrm{Vnew}_{[t]}$, $\mathrm{S}_{[t+1]}$ 并保存起来.
  3. 并行 地处理每个 chunk, 计算每一个 chunk 对应的 $\mathrm{O}_{[t]}$

理解了这些, 再看 chunk_gated_delta_rule 链路就很容易理解了, 其就是朴素地实现了如上 Gated Delta Rule chunk 计算公式. 这里就不详细说明了.