diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py
index 3bcfbacbb..de97daccf 100644
--- a/vllm/model_executor/models/qwen3_next.py
+++ b/vllm/model_executor/models/qwen3_next.py
@@ -28,11 +28,15 @@ from vllm.distributed import (
 )
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.logger import init_logger
+from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.attention import Attention
 from vllm.model_executor.layers.fla.ops import (
-    chunk_gated_delta_rule,
+    chunk_gated_delta_rule as fla_chunk_gated_delta_rule,
+)
+from vllm.model_executor.layers.fla.ops import (
     fused_recurrent_gated_delta_rule,
 )
+from vllm.model_executor.layers.fla.ops.chunk import l2norm_fwd
 from vllm.model_executor.layers.fused_moe import SharedFusedMoE
 from vllm.model_executor.layers.layernorm import (
     GemmaRMSNorm as Qwen3NextRMSNorm,
@@ -101,6 +105,113 @@ logger = init_logger(__name__)
 KVCache = tuple[torch.Tensor, torch.Tensor]
 
 
+def fi_chunk_gated_delta_rule(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    v: torch.Tensor,
+    g: torch.Tensor,
+    beta: torch.Tensor,
+    initial_state: torch.Tensor,
+    output_final_state: bool,
+    cu_seqlens: torch.LongTensor | None = None,
+    head_first: bool = False,
+    use_qk_l2norm_in_kernel: bool = True,
+):
+    from flashinfer.gdn_prefill import (
+        chunk_gated_delta_rule as chunk_gated_delta_rule_fi,
+    )
+
+    if use_qk_l2norm_in_kernel:
+        q = l2norm_fwd(q)
+        k = l2norm_fwd(k)
+
+    # use flashinfer implementation
+    q = q.squeeze(0).contiguous()
+    k = k.squeeze(0).contiguous()
+    v = v.squeeze(0).contiguous()
+
+    g = g.squeeze(0).contiguous()
+    beta = beta.squeeze(0).contiguous()
+    fi_state = initial_state.to(torch.float32)
+    fi_g = g.to(torch.float32)
+    fi_beta = beta.to(torch.float32)
+    return chunk_gated_delta_rule_fi(
+        q=q,
+        k=k,
+        v=v,
+        g=torch.exp(fi_g),
+        beta=fi_beta,
+        initial_state=fi_state,
+        output_final_state=output_final_state,
+        cu_seqlens=cu_seqlens,
+    )
+
+
+@CustomOp.register("chunk_gated_delta_rule")
+class ChunkGatedDeltaRule(CustomOp):
+    def __init__(self) -> None:
+        super().__init__()
+        if current_platform.is_cuda() and current_platform.is_device_capability(90):
+            logger.info_once(
+                "Using FlashInfer GDN prefill kernel on CUDA compute capability 90"
+            )
+            self._forward_method = self.forward_cuda
+        else:
+            self._forward_method = self.forward_native
+
+    def forward_cuda(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        g: torch.Tensor,
+        beta: torch.Tensor,
+        initial_state: torch.Tensor,
+        output_final_state: bool,
+        cu_seqlens: torch.LongTensor | None = None,
+        head_first: bool = False,
+        use_qk_l2norm_in_kernel: bool = True,
+    ):
+        return fi_chunk_gated_delta_rule(
+            q=q,
+            k=k,
+            v=v,
+            g=g,
+            beta=beta,
+            initial_state=initial_state,
+            output_final_state=output_final_state,
+            cu_seqlens=cu_seqlens,
+            head_first=head_first,
+            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
+        )
+
+    def forward_native(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        g: torch.Tensor,
+        beta: torch.Tensor,
+        initial_state: torch.Tensor,
+        output_final_state: bool,
+        cu_seqlens: torch.LongTensor | None = None,
+        head_first: bool = False,
+        use_qk_l2norm_in_kernel: bool = True,
+    ):
+        return fla_chunk_gated_delta_rule(
+            q=q,
+            k=k,
+            v=v,
+            g=g,
+            beta=beta,
+            initial_state=initial_state,
+            output_final_state=output_final_state,
+            cu_seqlens=cu_seqlens,
+            head_first=head_first,
+            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
+        )
+
+
 class Qwen3NextSparseMoeBlock(nn.Module):
     def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
@@ -362,6 +473,8 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
             prefix=f"{prefix}.out_proj",
         )
 
+        self.chunk_gated_delta_rule = ChunkGatedDeltaRule()
+
        compilation_config = get_current_vllm_config().compilation_config
         if prefix in compilation_config.static_forward_context:
             raise ValueError(f"Duplicate layer name: {prefix}")
@@ -647,7 +760,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
             (
                 core_attn_out_non_spec,
                 last_recurrent_state,
-            ) = chunk_gated_delta_rule(
+            ) = self.chunk_gated_delta_rule(
                 q=query_non_spec,
                 k=key_non_spec,
                 v=value_non_spec,
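A note on the dispatch this diff introduces: `ChunkGatedDeltaRule` picks its kernel once at construction time rather than branching per call, so the call site in `Qwen3NextGatedDeltaNet.forward` stays identical on every platform. The SM90 path also adapts the calling convention, squeezing away the dummy batch dim and handing FlashInfer fp32 gates in linear space (`exp(g)`) instead of the log-space `g` the FLA kernel takes. Below is a minimal usage sketch, not part of the patch; shapes are hypothetical and follow the varlen-packed `[1, total_tokens, heads, dim]` layout the FLA path uses, and constructing the op assumes a vLLM context where `CustomOp` can initialize:

```python
# Hypothetical sketch: two packed sequences (lengths 16 and 32) with
# boundaries marked by cu_seqlens, one recurrent state per sequence.
import torch

total_tokens, num_heads, head_dim = 48, 4, 64
q = torch.randn(1, total_tokens, num_heads, head_dim, dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)
g = torch.randn(1, total_tokens, num_heads, dtype=torch.bfloat16)    # log-space decay
beta = torch.rand(1, total_tokens, num_heads, dtype=torch.bfloat16)  # update gate
state0 = torch.zeros(2, num_heads, head_dim, head_dim)               # one state per sequence
cu_seqlens = torch.tensor([0, 16, 48], dtype=torch.long)

# Kernel choice happens here: FlashInfer GDN prefill on SM90, FLA elsewhere.
op = ChunkGatedDeltaRule()
core_attn_out, last_state = op(
    q=q, k=k, v=v, g=g, beta=beta,
    initial_state=state0,
    output_final_state=True,
    cu_seqlens=cu_seqlens,
)
```

Registering the wrapper as a `CustomOp` also means the backend selection can be inspected and overridden through vLLM's usual custom-op machinery instead of an ad-hoc flag at the call site.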