[Perf] add packed recurrent fast path for decode (#36596)
Signed-off-by: hdj <1293066020@qq.com> Co-authored-by: Roger Wang <hey@rogerw.io>
This commit is contained in:
@@ -10,6 +10,7 @@ from einops import rearrange
|
||||
from torch import nn
|
||||
from transformers.activations import ACT2FN
|
||||
|
||||
from vllm import envs
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import (
|
||||
CacheConfig,
|
||||
@@ -34,6 +35,7 @@ from vllm.model_executor.layers.fla.ops import (
|
||||
chunk_gated_delta_rule as fla_chunk_gated_delta_rule,
|
||||
)
|
||||
from vllm.model_executor.layers.fla.ops import (
|
||||
fused_recurrent_gated_delta_rule_packed_decode,
|
||||
fused_sigmoid_gating_delta_rule_update,
|
||||
)
|
||||
from vllm.model_executor.layers.fla.ops.chunk import l2norm_fwd
|
||||
@@ -474,6 +476,9 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
)
|
||||
|
||||
self.chunk_gated_delta_rule = ChunkGatedDeltaRule()
|
||||
self.enable_packed_recurrent_decode = (
|
||||
envs.VLLM_ENABLE_FLA_PACKED_RECURRENT_DECODE
|
||||
)
|
||||
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
@@ -747,9 +752,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
a: torch.Tensor,
|
||||
core_attn_out: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Core attention computation (called by custom op).
|
||||
"""
|
||||
forward_context = get_forward_context()
|
||||
attn_metadata: AttentionMetadata = forward_context.attn_metadata
|
||||
|
||||
@@ -762,6 +764,22 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
assert isinstance(attn_metadata, dict)
|
||||
attn_metadata = attn_metadata[self.prefix]
|
||||
assert isinstance(attn_metadata, GDNAttentionMetadata)
|
||||
|
||||
if (
|
||||
self.enable_packed_recurrent_decode
|
||||
and attn_metadata.spec_sequence_masks is None
|
||||
and attn_metadata.num_prefills == 0
|
||||
and attn_metadata.num_decodes > 0
|
||||
):
|
||||
return self._forward_core_decode_non_spec(
|
||||
mixed_qkv=mixed_qkv,
|
||||
b=b,
|
||||
a=a,
|
||||
core_attn_out=core_attn_out,
|
||||
attn_metadata=attn_metadata,
|
||||
virtual_engine=forward_context.virtual_engine,
|
||||
)
|
||||
|
||||
has_initial_state = attn_metadata.has_initial_state
|
||||
spec_query_start_loc = attn_metadata.spec_query_start_loc
|
||||
non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
|
||||
@@ -946,6 +964,55 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
else:
|
||||
core_attn_out[:num_actual_tokens] = core_attn_out_non_spec.squeeze(0)
|
||||
|
||||
def _forward_core_decode_non_spec(
    self,
    mixed_qkv: torch.Tensor,
    b: torch.Tensor,
    a: torch.Tensor,
    core_attn_out: torch.Tensor,
    attn_metadata: GDNAttentionMetadata,
    virtual_engine: int,
):
    """
    Packed non-spec decode fast path of the core attention computation.

    Trims ``mixed_qkv``, ``b`` and ``a`` to ``attn_metadata.num_actual_tokens``
    rows, runs ``causal_conv1d_update`` against the conv state held in this
    layer's KV cache, then hands the result to
    ``fused_recurrent_gated_delta_rule_packed_decode`` together with the
    SSM state and a view of ``core_attn_out`` passed as ``out``.

    Returns ``None``; the result is presumably written into
    ``core_attn_out`` through the ``out`` view (confirm against the kernel).
    """
    token_count = attn_metadata.num_actual_tokens

    # Per-token state slots, sliced once and shared by both kernels.
    state_slots = attn_metadata.non_spec_state_indices_tensor[:token_count]

    # kv_cache entry 0 is used as the conv state (last two dims swapped),
    # entry 1 as the SSM state.
    layer_cache = self.kv_cache[virtual_engine]
    conv_state = layer_cache[0].transpose(-1, -2)
    ssm_state = layer_cache[1]

    # Only the first `token_count` rows of each input are processed.
    qkv_tokens = mixed_qkv[:token_count]
    b_tokens = b[:token_count]
    a_tokens = a[:token_count]

    # Collapse the conv weight to 2-D using its dim-0 and dim-2 sizes.
    weight = self.conv1d.weight
    flat_weight = weight.view(weight.size(0), weight.size(2))

    conv_out = causal_conv1d_update(
        qkv_tokens,
        conv_state,
        flat_weight,
        self.conv1d.bias,
        self.activation,
        conv_state_indices=state_slots,
        validate_data=False,
    )

    # View over the caller's output buffer with an extra dim at position 1.
    out_view = core_attn_out[:token_count].unsqueeze(1)
    fused_recurrent_gated_delta_rule_packed_decode(
        mixed_qkv=conv_out,
        a=a_tokens,
        b=b_tokens,
        A_log=self.A_log,
        dt_bias=self.dt_bias,
        scale=self.head_k_dim**-0.5,
        initial_state=ssm_state,
        out=out_view,
        ssm_state_indices=state_slots,
        use_qk_l2norm_in_kernel=True,
    )
    return
|
||||
|
||||
|
||||
class Qwen3NextAttention(nn.Module):
|
||||
def __init__(
|
||||
|
||||
Reference in New Issue
Block a user