[Model] Mamba2 varlen refactor (#21467)

Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Co-authored-by: RishiAstra <40644327+RishiAstra@users.noreply.github.com>
This commit is contained in:
Chih-Chieh Yang
2025-09-26 07:31:14 -04:00
committed by GitHub
parent 633f943e30
commit 2b6b1d7809
10 changed files with 722 additions and 864 deletions

View File

@@ -29,7 +29,7 @@ from vllm.model_executor.layers.mamba.ops.layernorm_gated import rms_norm_gated
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_state_update)
from vllm.model_executor.layers.mamba.ops.ssd_combined import (
mamba_chunk_scan_combined)
mamba_chunk_scan_combined_varlen)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import (
LoaderFunction, composed_weight_loader, sharded_weight_loader)
@@ -504,6 +504,7 @@ class MambaMixer2(MambaBase, CustomOp):
seq_idx_p = attn_metadata.seq_idx_p
chunk_indices_p = attn_metadata.chunk_indices_p
chunk_offsets_p = attn_metadata.chunk_offsets_p
query_start_loc_p = attn_metadata.query_start_loc_p
# 1. Gated MLP's linear projection
projected_states, _ = self.in_proj(hidden_states)
@@ -545,6 +546,7 @@ class MambaMixer2(MambaBase, CustomOp):
out, _ = self.out_proj(hidden_states)
return out
# NOTE: V0 puts prefill before decode; V1 puts decode before prefill
num_prefills = attn_metadata.num_prefills # request count
num_decodes = attn_metadata.num_decode_tokens # token count (=request)
num_prefill_tokens = attn_metadata.num_prefill_tokens # token count
@@ -570,9 +572,6 @@ class MambaMixer2(MambaBase, CustomOp):
[num_decodes, num_prefills],
dim=0,
)
query_start_loc_p = (
attn_metadata.query_start_loc[-num_prefills - 1:] -
num_decodes if has_prefill else None)
# Preallocate output tensor to avoid memcpy cost for merging prefill
# and decode outputs
@@ -620,15 +619,15 @@ class MambaMixer2(MambaBase, CustomOp):
ssm_state[state_indices_tensor_p], 0)
# NOTE: final output is an in-place update of out tensor
varlen_state = mamba_chunk_scan_combined(
hidden_states_p.view(1, num_prefill_tokens,
varlen_states = mamba_chunk_scan_combined_varlen(
hidden_states_p.view(num_prefill_tokens,
self.num_heads // self.tp_size,
self.head_dim),
dt_p.unsqueeze(0),
dt_p,
self.A,
B_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size,
B_p.view(num_prefill_tokens, self.n_groups // self.tp_size,
-1),
C_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size,
C_p.view(num_prefill_tokens, self.n_groups // self.tp_size,
-1),
chunk_size=chunk_size,
D=self.D,
@@ -639,17 +638,15 @@ class MambaMixer2(MambaBase, CustomOp):
chunk_offsets=chunk_offsets_p,
cu_seqlens=query_start_loc_p,
initial_states=initial_states,
return_varlen_states=True,
return_final_states=False,
dt_softplus=True,
dt_limit=(0.0, float("inf")),
out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1,
out=preallocated_ssm_out_p.view(num_prefill_tokens, -1,
self.head_dim),
state_dtype=ssm_state.dtype)
# update ssm states
# - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
ssm_state[state_indices_tensor_p] = varlen_state
ssm_state[state_indices_tensor_p] = varlen_states
# Process decode requests
if has_decode: