[Kernel] Add fused_sigmoid_gating_delta_rule_update kernel for Qwen3 Next (#35777)
Signed-off-by: Xin Yang <xyangx@amazon.com>
This commit is contained in:
@@ -34,7 +34,7 @@ from vllm.model_executor.layers.fla.ops import (
|
||||
chunk_gated_delta_rule as fla_chunk_gated_delta_rule,
|
||||
)
|
||||
from vllm.model_executor.layers.fla.ops import (
|
||||
fused_recurrent_gated_delta_rule,
|
||||
fused_sigmoid_gating_delta_rule_update,
|
||||
)
|
||||
from vllm.model_executor.layers.fla.ops.chunk import l2norm_fwd
|
||||
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
|
||||
@@ -731,41 +731,40 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
mixed_qkv_non_spec
|
||||
)
|
||||
|
||||
g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias)
|
||||
|
||||
if spec_sequence_masks is not None:
|
||||
if attn_metadata.num_prefills == 0 and attn_metadata.num_decodes == 0:
|
||||
g_spec = g
|
||||
beta_spec = beta
|
||||
g_non_spec = None
|
||||
beta_non_spec = None
|
||||
else:
|
||||
g_spec = g.index_select(1, spec_token_indx)
|
||||
beta_spec = beta.index_select(1, spec_token_indx)
|
||||
if attn_metadata.num_prefills > 0:
|
||||
g, beta = fused_gdn_gating(self.A_log, a, b, self.dt_bias)
|
||||
if spec_sequence_masks is not None:
|
||||
g_non_spec = g.index_select(1, non_spec_token_indx)
|
||||
beta_non_spec = beta.index_select(1, non_spec_token_indx)
|
||||
else:
|
||||
g_non_spec = g
|
||||
beta_non_spec = beta
|
||||
else:
|
||||
g_spec = None
|
||||
beta_spec = None
|
||||
g_non_spec = g
|
||||
beta_non_spec = beta
|
||||
g_non_spec = None
|
||||
beta_non_spec = None
|
||||
|
||||
# 2. Recurrent attention
|
||||
|
||||
# 2.1: Process the multi-query part
|
||||
if spec_sequence_masks is not None:
|
||||
core_attn_out_spec, last_recurrent_state = fused_recurrent_gated_delta_rule(
|
||||
q=query_spec,
|
||||
k=key_spec,
|
||||
v=value_spec,
|
||||
g=g_spec,
|
||||
beta=beta_spec,
|
||||
initial_state=ssm_state,
|
||||
inplace_final_state=True,
|
||||
cu_seqlens=spec_query_start_loc[: attn_metadata.num_spec_decodes + 1],
|
||||
ssm_state_indices=spec_state_indices_tensor,
|
||||
num_accepted_tokens=num_accepted_tokens,
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
core_attn_out_spec, last_recurrent_state = (
|
||||
fused_sigmoid_gating_delta_rule_update(
|
||||
A_log=self.A_log,
|
||||
a=a,
|
||||
b=b,
|
||||
dt_bias=self.dt_bias,
|
||||
q=query_spec,
|
||||
k=key_spec,
|
||||
v=value_spec,
|
||||
initial_state=ssm_state,
|
||||
inplace_final_state=True,
|
||||
cu_seqlens=spec_query_start_loc[
|
||||
: attn_metadata.num_spec_decodes + 1
|
||||
],
|
||||
ssm_state_indices=spec_state_indices_tensor,
|
||||
num_accepted_tokens=num_accepted_tokens,
|
||||
use_qk_l2norm_in_kernel=True,
|
||||
)
|
||||
)
|
||||
else:
|
||||
core_attn_out_spec, last_recurrent_state = None, None
|
||||
@@ -794,12 +793,14 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
)
|
||||
elif attn_metadata.num_decodes > 0:
|
||||
core_attn_out_non_spec, last_recurrent_state = (
|
||||
fused_recurrent_gated_delta_rule(
|
||||
fused_sigmoid_gating_delta_rule_update(
|
||||
A_log=self.A_log,
|
||||
a=a,
|
||||
b=b,
|
||||
dt_bias=self.dt_bias,
|
||||
q=query_non_spec,
|
||||
k=key_non_spec,
|
||||
v=value_non_spec,
|
||||
g=g_non_spec,
|
||||
beta=beta_non_spec,
|
||||
initial_state=ssm_state,
|
||||
inplace_final_state=True,
|
||||
cu_seqlens=non_spec_query_start_loc[
|
||||
|
||||
Reference in New Issue
Block a user