[Perf] add packed recurrent fast path for decode (#36596)

Signed-off-by: hdj <1293066020@qq.com>
Co-authored-by: Roger Wang <hey@rogerw.io>

Author: caozuoba
Date: 2026-03-12 19:01:57 +08:00
Committed by: GitHub
Parent: 06e0bc21d2
Commit: 9e19f8338b
5 changed files with 402 additions and 4 deletions

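At a high level, the change adds an opt-in fast path: when every request in the batch is a plain single-token decode (no prefills, no speculative tokens), the gated delta-rule update runs through a packed fused kernel instead of the generic varlen path. A minimal usage sketch, assuming the flag is parsed as a boolean environment variable; the model name is illustrative, not taken from this diff:

    import os

    # Set the opt-in flag before vllm is imported so vllm.envs picks it up.
    os.environ["VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE"] = "1"

    from vllm import LLM  # noqa: E402

    llm = LLM(model="Qwen/Qwen3-Next-80B-A3B-Instruct")  # illustrative checkpoint
    print(llm.generate(["Hello"])[0].outputs[0].text)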

@@ -10,6 +10,7 @@ from einops import rearrange
 from torch import nn
 from transformers.activations import ACT2FN
+from vllm import envs
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import (
     CacheConfig,
@@ -34,6 +35,7 @@ from vllm.model_executor.layers.fla.ops import (
     chunk_gated_delta_rule as fla_chunk_gated_delta_rule,
 )
 from vllm.model_executor.layers.fla.ops import (
+    fused_recurrent_gated_delta_rule_packed_decode,
     fused_sigmoid_gating_delta_rule_update,
 )
 from vllm.model_executor.layers.fla.ops.chunk import l2norm_fwd
@@ -474,6 +476,9 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
         )
         self.chunk_gated_delta_rule = ChunkGatedDeltaRule()
+        self.enable_packed_recurrent_decode = (
+            envs.VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE
+        )
         compilation_config = get_current_vllm_config().compilation_config
         if prefix in compilation_config.static_forward_context:
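The flag is read once at construction and cached on the module, so the per-step dispatch later in the diff is a cheap attribute check. As a rough sketch of where the flag itself would live, here is a hypothetical declaration in the lambda-registry style vllm/envs.py uses for boolean flags; the real default and parsing sit in one of the five changed files and are not shown in this hunk:

    import os

    # Hypothetical registry entry; the name matches the diff, the parsing
    # and default value are assumptions.
    environment_variables = {
        "VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE": lambda: bool(
            int(os.getenv("VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE", "0"))
        ),
    }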
@@ -747,9 +752,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
         a: torch.Tensor,
         core_attn_out: torch.Tensor,
     ):
-        """
-        Core attention computation (called by custom op).
-        """
         forward_context = get_forward_context()
         attn_metadata: AttentionMetadata = forward_context.attn_metadata
@@ -762,6 +764,22 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
         assert isinstance(attn_metadata, dict)
         attn_metadata = attn_metadata[self.prefix]
         assert isinstance(attn_metadata, GDNAttentionMetadata)
+
+        if (
+            self.enable_packed_recurrent_decode
+            and attn_metadata.spec_sequence_masks is None
+            and attn_metadata.num_prefills == 0
+            and attn_metadata.num_decodes > 0
+        ):
+            return self._forward_core_decode_non_spec(
+                mixed_qkv=mixed_qkv,
+                b=b,
+                a=a,
+                core_attn_out=core_attn_out,
+                attn_metadata=attn_metadata,
+                virtual_engine=forward_context.virtual_engine,
+            )
+
         has_initial_state = attn_metadata.has_initial_state
         spec_query_start_loc = attn_metadata.spec_query_start_loc
         non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
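The dispatch condition above is deliberately narrow. As a standalone paraphrase (field names mirror the GDNAttentionMetadata attributes used in the diff), the fast path only fires for pure-decode, non-speculative batches:

    # Paraphrase of the dispatch test; meta stands in for GDNAttentionMetadata.
    def use_packed_decode(meta, enabled: bool) -> bool:
        return (
            enabled
            and meta.spec_sequence_masks is None  # no speculative tokens in flight
            and meta.num_prefills == 0  # batch contains no prefill requests
            and meta.num_decodes > 0  # and at least one decode request
        )

Anything else (mixed prefill/decode batches, speculative decoding) falls through to the existing varlen path, so behavior there is unchanged.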
@@ -946,6 +964,55 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
         else:
             core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)
 
+    def _forward_core_decode_non_spec(
+        self,
+        mixed_qkv: torch.Tensor,
+        b: torch.Tensor,
+        a: torch.Tensor,
+        core_attn_out: torch.Tensor,
+        attn_metadata: GDNAttentionMetadata,
+        virtual_engine: int,
+    ):
+        """
+        Core attention computation with a packed non-spec decode fast path.
+        """
+        non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor  # noqa: E501
+        self_kv_cache = self.kv_cache[virtual_engine]
+        conv_state = self_kv_cache[0].transpose(-1, -2)
+        ssm_state = self_kv_cache[1]
+        num_actual_tokens = attn_metadata.num_actual_tokens
+
+        mixed_qkv = mixed_qkv[:num_actual_tokens]
+        b = b[:num_actual_tokens]
+        a = a[:num_actual_tokens]
+
+        conv_weights = self.conv1d.weight.view(
+            self.conv1d.weight.size(0), self.conv1d.weight.size(2)
+        )
+        mixed_qkv_non_spec = causal_conv1d_update(
+            mixed_qkv,
+            conv_state,
+            conv_weights,
+            self.conv1d.bias,
+            self.activation,
+            conv_state_indices=non_spec_state_indices_tensor[:num_actual_tokens],
+            validate_data=False,
+        )
+
+        out_buf = core_attn_out[:num_actual_tokens].unsqueeze(1)
+        fused_recurrent_gated_delta_rule_packed_decode(
+            mixed_qkv=mixed_qkv_non_spec,
+            a=a,
+            b=b,
+            A_log=self.A_log,
+            dt_bias=self.dt_bias,
+            scale=self.head_k_dim**-0.5,
+            initial_state=ssm_state,
+            out=out_buf,
+            ssm_state_indices=non_spec_state_indices_tensor[:num_actual_tokens],
+            use_qk_l2norm_in_kernel=True,
+        )
+        return
+
 
 class Qwen3NextAttention(nn.Module):
     def __init__(
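The packing trick is visible in the unsqueeze(1) above: in a pure decode batch each request contributes exactly one token, so the token axis doubles as the batch axis and the kernel can write results in place through the out= buffer. A toy shape check, with dimensions chosen arbitrarily rather than taken from the model:

    import torch

    # Toy dimensions; the real sizes come from the model config.
    num_tokens, num_heads, head_dim = 8, 4, 64
    core_attn_out = torch.empty(num_tokens, num_heads, head_dim)

    # Each decode request is a sequence of length 1, so no varlen
    # bookkeeping (query start locations, concat/unpad) is needed.
    out_buf = core_attn_out[:num_tokens].unsqueeze(1)
    assert out_buf.shape == (num_tokens, 1, num_heads, head_dim)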