[PERF] Qwen3-next MTP speedup (replace boolean-mask indexing with index_select / index_copy_ to reduce device-to-host (d2h) synchronization) (#26437)
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
This commit is contained in:
@@ -45,7 +45,7 @@ def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]
|
||||
"""
|
||||
|
||||
cache_entries: tuple[tuple | None, dict | None, Any] = []
|
||||
cache_size = 4
|
||||
cache_size = 8
|
||||
|
||||
@functools.wraps(fn)
|
||||
def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||
|
||||
@@ -423,7 +423,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
(query, key),
|
||||
)
|
||||
value = rearrange(value, "l (h d) -> 1 l h d", d=self.head_v_dim)
|
||||
return query, key, value
|
||||
return query.contiguous(), key.contiguous(), value.contiguous()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -455,7 +455,8 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
spec_query_start_loc = attn_metadata.spec_query_start_loc
|
||||
non_spec_query_start_loc = attn_metadata.non_spec_query_start_loc
|
||||
spec_sequence_masks = attn_metadata.spec_sequence_masks
|
||||
spec_token_masks = attn_metadata.spec_token_masks
|
||||
spec_token_indx = attn_metadata.spec_token_indx
|
||||
non_spec_token_indx = attn_metadata.non_spec_token_indx
|
||||
spec_state_indices_tensor = attn_metadata.spec_state_indices_tensor # noqa: E501
|
||||
non_spec_state_indices_tensor = attn_metadata.non_spec_state_indices_tensor # noqa: E501
|
||||
self_kv_cache = self.kv_cache[forward_context.virtual_engine]
|
||||
@@ -463,8 +464,6 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
ssm_state = self_kv_cache[1]
|
||||
num_actual_tokens = attn_metadata.num_actual_tokens
|
||||
num_accepted_tokens = attn_metadata.num_accepted_tokens
|
||||
if spec_token_masks is not None:
|
||||
spec_token_masks = spec_token_masks[:num_actual_tokens]
|
||||
|
||||
# 1. Set up dimensions for reshapes later
|
||||
projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states[:num_actual_tokens])
|
||||
@@ -487,8 +486,8 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
mixed_qkv_spec = mixed_qkv
|
||||
mixed_qkv_non_spec = None
|
||||
else:
|
||||
mixed_qkv_spec = mixed_qkv[spec_token_masks]
|
||||
mixed_qkv_non_spec = mixed_qkv[~spec_token_masks]
|
||||
mixed_qkv_spec = mixed_qkv.index_select(0, spec_token_indx)
|
||||
mixed_qkv_non_spec = mixed_qkv.index_select(0, non_spec_token_indx)
|
||||
else:
|
||||
mixed_qkv_spec = None
|
||||
mixed_qkv_non_spec = mixed_qkv
|
||||
@@ -558,10 +557,10 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
g_non_spec = None
|
||||
beta_non_spec = None
|
||||
else:
|
||||
g_spec = g[:, spec_token_masks]
|
||||
beta_spec = beta[:, spec_token_masks]
|
||||
g_non_spec = g[:, ~spec_token_masks]
|
||||
beta_non_spec = beta[:, ~spec_token_masks]
|
||||
g_spec = g.index_select(1, spec_token_indx)
|
||||
beta_spec = beta.index_select(1, spec_token_indx)
|
||||
g_non_spec = g.index_select(1, non_spec_token_indx)
|
||||
beta_non_spec = beta.index_select(1, non_spec_token_indx)
|
||||
else:
|
||||
g_spec = None
|
||||
beta_spec = None
|
||||
@@ -638,8 +637,9 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
|
||||
dtype=core_attn_out_non_spec.dtype,
|
||||
device=core_attn_out_non_spec.device,
|
||||
)
|
||||
core_attn_out[:, spec_token_masks] = core_attn_out_spec
|
||||
core_attn_out[:, ~spec_token_masks] = core_attn_out_non_spec
|
||||
core_attn_out.index_copy_(1, spec_token_indx, core_attn_out_spec)
|
||||
core_attn_out.index_copy_(1, non_spec_token_indx, core_attn_out_non_spec)
|
||||
|
||||
elif spec_sequence_masks is not None:
|
||||
core_attn_out = core_attn_out_spec
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user