[Perf] Add packed recurrent fast path for decode (#36596)

Signed-off-by: hdj <1293066020@qq.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
Author: caozuoba
Committed: 2026-03-12 19:01:57 +08:00 by GitHub
Parent: 06e0bc21d2
Commit: 9e19f8338b
5 changed files with 402 additions and 4 deletions

View File

@@ -0,0 +1,98 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import pytest
import torch
from vllm.model_executor.layers.fla.ops import (
fused_recurrent_gated_delta_rule,
fused_recurrent_gated_delta_rule_packed_decode,
)
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Need CUDA device")
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
@pytest.mark.parametrize("strided_mixed_qkv", [False, True])
def test_fused_recurrent_packed_decode_matches_reference(
dtype: torch.dtype, strided_mixed_qkv: bool
):
torch.manual_seed(0)
# Small but representative GDN config (Qwen3Next defaults are K=128, V=128).
B = 32
H = 4
HV = 8 # grouped value attention: HV must be divisible by H
K = 128
V = 128
qkv_dim = 2 * (H * K) + (HV * V)
device = torch.device("cuda")
if strided_mixed_qkv:
# Simulate a packed view into a larger projection buffer:
# mixed_qkv.stride(0) > mixed_qkv.shape[1]
proj = torch.randn((B, qkv_dim + 64), device=device, dtype=dtype)
mixed_qkv = proj[:, :qkv_dim]
else:
mixed_qkv = torch.randn((B, qkv_dim), device=device, dtype=dtype)
a = torch.randn((B, HV), device=device, dtype=dtype)
b = torch.randn((B, HV), device=device, dtype=dtype)
A_log = torch.randn((HV,), device=device, dtype=dtype)
dt_bias = torch.randn((HV,), device=device, dtype=dtype)
# Continuous batching indices (include PAD_SLOT_ID=-1 cases).
ssm_state_indices = torch.arange(B, device=device, dtype=torch.int32)
ssm_state_indices[-3:] = -1
state0 = torch.randn((B, HV, V, K), device=device, dtype=dtype)
state_ref = state0.clone()
state_packed = state0.clone()
out_packed = torch.empty((B, 1, HV, V), device=device, dtype=dtype)
# Reference path: materialize contiguous Q/K/V + explicit gating.
q, k, v = torch.split(mixed_qkv, [H * K, H * K, HV * V], dim=-1)
q = q.view(B, H, K).unsqueeze(1).contiguous()
k = k.view(B, H, K).unsqueeze(1).contiguous()
v = v.view(B, HV, V).unsqueeze(1).contiguous()
x = a.float() + dt_bias.float()
softplus_x = torch.where(
x <= 20.0, torch.log1p(torch.exp(torch.clamp(x, max=20.0))), x
)
g = (-torch.exp(A_log.float()) * softplus_x).unsqueeze(1)
beta = torch.sigmoid(b.float()).to(dtype).unsqueeze(1)
out_ref, state_ref = fused_recurrent_gated_delta_rule(
q=q,
k=k,
v=v,
g=g,
beta=beta,
scale=K**-0.5,
initial_state=state_ref,
inplace_final_state=True,
cu_seqlens=None,
ssm_state_indices=ssm_state_indices,
use_qk_l2norm_in_kernel=True,
)
# Packed path: fused gating + recurrent directly from packed mixed_qkv.
fused_recurrent_gated_delta_rule_packed_decode(
mixed_qkv=mixed_qkv,
a=a,
b=b,
A_log=A_log,
dt_bias=dt_bias,
scale=K**-0.5,
initial_state=state_packed,
out=out_packed,
ssm_state_indices=ssm_state_indices,
use_qk_l2norm_in_kernel=True,
)
atol = 2e-2 if dtype != torch.float32 else 1e-4
rtol = 1e-2 if dtype != torch.float32 else 1e-4
torch.testing.assert_close(out_packed, out_ref, rtol=rtol, atol=atol)
torch.testing.assert_close(state_packed, state_ref, rtol=rtol, atol=atol)

View File

@@ -96,6 +96,7 @@ if TYPE_CHECKING:
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
VLLM_SKIP_P2P_CHECK: bool = False
VLLM_DISABLED_KERNELS: list[str] = []
VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE: bool = True
VLLM_DISABLE_PYNCCL: bool = False
VLLM_USE_OINK_OPS: bool = False
VLLM_ROCM_USE_AITER: bool = False
@@ -899,6 +900,9 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_DISABLED_KERNELS": lambda: [] "VLLM_DISABLED_KERNELS": lambda: []
if "VLLM_DISABLED_KERNELS" not in os.environ if "VLLM_DISABLED_KERNELS" not in os.environ
else os.environ["VLLM_DISABLED_KERNELS"].split(","), else os.environ["VLLM_DISABLED_KERNELS"].split(","),
"VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE": lambda: bool(
int(os.getenv("VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE", "1"))
),
# Disable pynccl (using torch.distributed instead)
"VLLM_DISABLE_PYNCCL": lambda: (
os.getenv("VLLM_DISABLE_PYNCCL", "False").lower() in ("true", "1")

View File

@@ -7,7 +7,10 @@
# the following copyright notice:
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
from .chunk import chunk_gated_delta_rule
from .fused_recurrent import (
fused_recurrent_gated_delta_rule,
fused_recurrent_gated_delta_rule_packed_decode,
)
from .fused_sigmoid_gating import fused_sigmoid_gating_delta_rule_update
from .layernorm_guard import RMSNormGated
@@ -15,5 +18,6 @@ __all__ = [
"RMSNormGated", "RMSNormGated",
"chunk_gated_delta_rule", "chunk_gated_delta_rule",
"fused_recurrent_gated_delta_rule", "fused_recurrent_gated_delta_rule",
"fused_recurrent_gated_delta_rule_packed_decode",
"fused_sigmoid_gating_delta_rule_update", "fused_sigmoid_gating_delta_rule_update",
] ]

View File

@@ -252,6 +252,231 @@ def fused_recurrent_gated_delta_rule_fwd(
return o, final_state
@triton.jit
def fused_recurrent_gated_delta_rule_packed_decode_kernel(
mixed_qkv,
a,
b,
A_log,
dt_bias,
o,
h0,
ht,
ssm_state_indices,
scale,
stride_mixed_qkv_tok: tl.constexpr,
stride_a_tok: tl.constexpr,
stride_b_tok: tl.constexpr,
stride_init_state_token: tl.constexpr,
stride_final_state_token: tl.constexpr,
stride_indices_seq: tl.constexpr,
H: tl.constexpr,
HV: tl.constexpr,
K: tl.constexpr,
V: tl.constexpr,
BK: tl.constexpr,
BV: tl.constexpr,
SOFTPLUS_THRESHOLD: tl.constexpr,
USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
):
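# Launch grid is (NV, B * HV): axis 0 picks the value-dim block, axis 1 the
# (batch, value-head) pair; each value head maps back to query head i_h below.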
i_v, i_nh = tl.program_id(0), tl.program_id(1)
i_n, i_hv = i_nh // HV, i_nh % HV
i_h = i_hv // (HV // H)
o_k = tl.arange(0, BK)
o_v = i_v * BV + tl.arange(0, BV)
mask_k = o_k < K
mask_v = o_v < V
mask_h = mask_v[:, None] & mask_k[None, :]
state_idx = tl.load(ssm_state_indices + i_n * stride_indices_seq).to(tl.int64)
p_o = o + (i_n * HV + i_hv) * V + o_v
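# PAD_SLOT_ID (-1) marks a padded slot: emit zeros and skip any state access.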
if state_idx < 0:
zero = tl.zeros([BV], dtype=tl.float32).to(p_o.dtype.element_ty)
tl.store(p_o, zero, mask=mask_v)
return
p_h0 = h0 + state_idx * stride_init_state_token
p_h0 = p_h0 + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
b_h = tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
p_mixed = mixed_qkv + i_n * stride_mixed_qkv_tok
q_off = i_h * K + o_k
k_off = (H * K) + i_h * K + o_k
v_off = (2 * H * K) + i_hv * V + o_v
b_q = tl.load(p_mixed + q_off, mask=mask_k, other=0).to(tl.float32)
b_k = tl.load(p_mixed + k_off, mask=mask_k, other=0).to(tl.float32)
b_v = tl.load(p_mixed + v_off, mask=mask_v, other=0).to(tl.float32)
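# Optional in-kernel L2 normalization of q and k (the Qwen3Next GDN path enables this).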
if USE_QK_L2NORM_IN_KERNEL:
b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
b_q = b_q * scale
a_val = tl.load(a + i_n * stride_a_tok + i_hv).to(tl.float32)
b_val = tl.load(b + i_n * stride_b_tok + i_hv).to(tl.float32)
A_log_val = tl.load(A_log + i_hv).to(tl.float32)
dt_bias_val = tl.load(dt_bias + i_hv).to(tl.float32)
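# Fused gating, matching the reference path: g = -exp(A_log) * softplus(a + dt_bias),
# with a linear tail above SOFTPLUS_THRESHOLD for numerical stability; beta = sigmoid(b),
# round-tripped through the storage dtype so precision matches the unfused reference.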
x = a_val + dt_bias_val
softplus_x = tl.where(x <= SOFTPLUS_THRESHOLD, tl.log(1.0 + tl.exp(x)), x)
g_val = -tl.exp(A_log_val) * softplus_x
beta_val = tl.sigmoid(b_val).to(b.dtype.element_ty).to(tl.float32)
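# Single-token gated delta rule:
#   h <- exp(g) * h          (gated decay)
#   v <- beta * (v - h k)    (delta correction against current memory)
#   h <- h + v k^T           (rank-1 state update)
#   o  = h q                 (readout)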
b_h *= exp(g_val)
b_v -= tl.sum(b_h * b_k[None, :], 1)
b_v *= beta_val
b_h += b_v[:, None] * b_k[None, :]
b_o = tl.sum(b_h * b_q[None, :], 1)
tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
p_ht = ht + state_idx * stride_final_state_token
p_ht = p_ht + i_hv * V * K + o_v[:, None] * K + o_k[None, :]
tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
def fused_recurrent_gated_delta_rule_packed_decode(
mixed_qkv: torch.Tensor,
a: torch.Tensor,
b: torch.Tensor,
A_log: torch.Tensor,
dt_bias: torch.Tensor,
scale: float,
initial_state: torch.Tensor,
out: torch.Tensor,
ssm_state_indices: torch.Tensor,
use_qk_l2norm_in_kernel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
if mixed_qkv.ndim != 2:
raise ValueError(
f"`mixed_qkv` must be a 2D tensor (got ndim={mixed_qkv.ndim})."
)
if mixed_qkv.stride(-1) != 1:
raise ValueError("`mixed_qkv` must be contiguous in the last dim.")
if a.ndim != 2 or b.ndim != 2:
raise ValueError(
f"`a` and `b` must be 2D tensors (got a.ndim={a.ndim}, b.ndim={b.ndim})."
)
if a.stride(-1) != 1 or b.stride(-1) != 1:
raise ValueError("`a`/`b` must be contiguous in the last dim.")
if A_log.ndim != 1 or dt_bias.ndim != 1:
raise ValueError("`A_log`/`dt_bias` must be 1D tensors.")
if A_log.stride(0) != 1 or dt_bias.stride(0) != 1:
raise ValueError("`A_log`/`dt_bias` must be contiguous.")
if ssm_state_indices.ndim != 1:
raise ValueError(
f"`ssm_state_indices` must be 1D for packed decode (got ndim={ssm_state_indices.ndim})."
)
if not out.is_contiguous():
raise ValueError("`out` must be contiguous.")
dev = mixed_qkv.device
if (
a.device != dev
or b.device != dev
or A_log.device != dev
or dt_bias.device != dev
or initial_state.device != dev
or out.device != dev
or ssm_state_indices.device != dev
):
raise ValueError("All inputs must be on the same device.")
B = mixed_qkv.shape[0]
if a.shape[0] != B or b.shape[0] != B:
raise ValueError(
"Mismatched batch sizes: "
f"mixed_qkv.shape[0]={B}, a.shape[0]={a.shape[0]}, b.shape[0]={b.shape[0]}."
)
if ssm_state_indices.shape[0] != B:
raise ValueError(
f"`ssm_state_indices` must have shape [B] (got {tuple(ssm_state_indices.shape)}; expected ({B},))."
)
if initial_state.ndim != 4:
raise ValueError(
f"`initial_state` must be a 4D tensor (got ndim={initial_state.ndim})."
)
if initial_state.stride(-1) != 1:
raise ValueError("`initial_state` must be contiguous in the last dim.")
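# The SSM state cache is laid out as [num_slots, HV, V, K] and indexed
# per request through ssm_state_indices.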
HV, V, K = initial_state.shape[-3:]
if a.shape[1] != HV or b.shape[1] != HV:
raise ValueError(
f"`a`/`b` must have shape [B, HV] with HV={HV} (got a.shape={tuple(a.shape)}, b.shape={tuple(b.shape)})."
)
if A_log.numel() != HV or dt_bias.numel() != HV:
raise ValueError(
f"`A_log` and `dt_bias` must have {HV} elements (got A_log.numel()={A_log.numel()}, dt_bias.numel()={dt_bias.numel()})."
)
if out.shape != (B, 1, HV, V):
raise ValueError(
f"`out` must have shape {(B, 1, HV, V)} (got out.shape={tuple(out.shape)})."
)
qkv_dim = mixed_qkv.shape[1]
qk_dim = qkv_dim - HV * V
if qk_dim <= 0 or qk_dim % 2 != 0:
raise ValueError(
f"Invalid packed `mixed_qkv` last dim={qkv_dim} for HV={HV}, V={V}."
)
q_dim = qk_dim // 2
if q_dim % K != 0:
raise ValueError(f"Invalid packed Q size {q_dim}: must be divisible by K={K}.")
H = q_dim // K
if H <= 0 or HV % H != 0:
raise ValueError(
f"Invalid head config inferred from mixed_qkv: H={H}, HV={HV}."
)
BK = triton.next_power_of_2(K)
if triton.cdiv(K, BK) != 1:
raise ValueError(
f"Packed decode kernel only supports NK=1 (got K={K}, BK={BK})."
)
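# Tile the value dimension into blocks of at most 32 lanes.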
BV = min(triton.next_power_of_2(V), 32)
num_stages = 3
num_warps = 1
stride_mixed_qkv_tok = mixed_qkv.stride(0)
stride_a_tok = a.stride(0)
stride_b_tok = b.stride(0)
stride_init_state_token = initial_state.stride(0)
stride_final_state_token = initial_state.stride(0)
stride_indices_seq = ssm_state_indices.stride(0)
NV = triton.cdiv(V, BV)
grid = (NV, B * HV)
fused_recurrent_gated_delta_rule_packed_decode_kernel[grid](
mixed_qkv=mixed_qkv,
a=a,
b=b,
A_log=A_log,
dt_bias=dt_bias,
o=out,
h0=initial_state,
ht=initial_state,
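# h0 and ht alias the same buffer, so the recurrent state is updated in place.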
ssm_state_indices=ssm_state_indices,
scale=scale,
stride_mixed_qkv_tok=stride_mixed_qkv_tok,
stride_a_tok=stride_a_tok,
stride_b_tok=stride_b_tok,
stride_init_state_token=stride_init_state_token,
stride_final_state_token=stride_final_state_token,
stride_indices_seq=stride_indices_seq,
H=H,
HV=HV,
K=K,
V=V,
BK=BK,
BV=BV,
SOFTPLUS_THRESHOLD=20.0,
USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
num_warps=num_warps,
num_stages=num_stages,
)
return out, initial_state
class FusedRecurrentFunction(torch.autograd.Function):
@staticmethod
def forward(

View File

@@ -10,6 +10,7 @@ from einops import rearrange
from torch import nn
from transformers.activations import ACT2FN
from vllm import envs
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (
CacheConfig,
@@ -34,6 +35,7 @@ from vllm.model_executor.layers.fla.ops import (
chunk_gated_delta_rule as fla_chunk_gated_delta_rule,
)
from vllm.model_executor.layers.fla.ops import (
fused_recurrent_gated_delta_rule_packed_decode,
fused_sigmoid_gating_delta_rule_update,
)
from vllm.model_executor.layers.fla.ops.chunk import l2norm_fwd
@@ -474,6 +476,9 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
)
self.chunk_gated_delta_rule = ChunkGatedDeltaRule()
self.enable_packed_recurrent_decode = (
envs.VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE
)
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
@@ -747,9 +752,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
a: torch.Tensor,
core_attn_out: torch.Tensor,
):
"""
Core attention computation (called by custom op).
"""
forward_context = get_forward_context()
attn_metadata: AttentionMetadata = forward_context.attn_metadata
@@ -762,6 +764,22 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, GDNAttentionMetadata)
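# Fast path: pure-decode, non-spec batches use the packed decode kernel,
# which reads q/k/v directly from mixed_qkv and fuses the gating
# computation, avoiding the materialized splits of the generic path below.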
if (
self.enable_packed_recurrent_decode
and attn_metadata.spec_sequence_masks is None
and attn_metadata.num_prefills == 0
and attn_metadata.num_decodes > 0
):
return self._forward_core_decode_non_spec(
mixed_qkv=mixed_qkv,
b=b,
a=a,
core_attn_out=core_attn_out,
attn_metadata=attn_metadata,
virtual_engine=forward_context.virtual_engine,
)
has_initial_state = attn_metadata.has_initial_state
spec_query_start_loc = attn_metadata.spec_query_start_loc
non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
@@ -946,6 +964,55 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
else:
core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)
def _forward_core_decode_non_spec(
self,
mixed_qkv: torch.Tensor,
b: torch.Tensor,
a: torch.Tensor,
core_attn_out: torch.Tensor,
attn_metadata: GDNAttentionMetadata,
virtual_engine: int,
):
"""
Core attention computation with a packed non-spec decode fast path.
"""
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501
self_kv_cache = self.kv_cache[virtual_engine]
conv_state = self_kv_cache[0].transpose(-1, -2)
ssm_state = self_kv_cache[1]
num_actual_tokens = attn_metadata.num_actual_tokens
mixed_qkv = mixed_qkv[:num_actual_tokens]
b = b[:num_actual_tokens]
a = a[:num_actual_tokens]
conv_weights = self.conv1d.weight.view(
self.conv1d.weight.size(0), self.conv1d.weight.size(2)
)
mixed_qkv_non_spec = causal_conv1d_update(
mixed_qkv,
conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
conv_state_indices=non_spec_state_indices_tensor[:num_actual_tokens],
validate_data=False,
)
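# The kernel writes through this [tokens, 1, HV, V] view directly into
# core_attn_out, so no copy-back is needed.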
out_buf = core_attn_out[:num_actual_tokens].unsqueeze(1)
fused_recurrent_gated_delta_rule_packed_decode(
mixed_qkv=mixed_qkv_non_spec,
a=a,
b=b,
A_log=self.A_log,
dt_bias=self.dt_bias,
scale=self.head_k_dim**-0.5,
initial_state=ssm_state,
out=out_buf,
ssm_state_indices=non_spec_state_indices_tensor[:num_actual_tokens],
use_qk_l2norm_in_kernel=True,
)
return
class Qwen3NextAttention(nn.Module):
def __init__(