[Model] Mamba2 varlen refactor (#21467)
Signed-off-by: Chih-Chieh-Yang <7364402+cyang49@users.noreply.github.com>
Co-authored-by: RishiAstra <40644327+RishiAstra@users.noreply.github.com>
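
In outline: the old path padded the flattened prefill tokens with a leading batch dimension of 1 and asked mamba_chunk_scan_combined to return per-request final states via return_varlen_states=True; the new mamba_chunk_scan_combined_varlen path consumes tokens-first tensors directly, with cu_seqlens (here query_start_loc_p) marking request boundaries. A minimal sketch of the layout change, with made-up sizes (not vLLM code):

    import torch

    # Sketch with made-up sizes: three prefill requests of lengths 3, 5, 2
    # packed into one flat token dimension.
    seq_lens = [3, 5, 2]
    num_prefill_tokens = sum(seq_lens)          # 10
    nheads, headdim = 4, 8

    # Old layout: a fake batch dim of 1 in front, and the kernel asked to
    # recover per-request states with return_varlen_states=True.
    x_old = torch.randn(1, num_prefill_tokens, nheads, headdim)

    # New layout: the flat tensor is passed as-is; cu_seqlens marks where
    # each request starts and ends.
    x_new = x_old.squeeze(0)                    # (10, 4, 8)
    cu_seqlens = torch.tensor([0, 3, 8, 10], dtype=torch.int32)
    assert x_new.shape == (num_prefill_tokens, nheads, headdim)
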
@@ -29,7 +29,7 @@ from vllm.model_executor.layers.mamba.ops.layernorm_gated import rms_norm_gated
 from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
     selective_state_update)
 from vllm.model_executor.layers.mamba.ops.ssd_combined import (
-    mamba_chunk_scan_combined)
+    mamba_chunk_scan_combined_varlen)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import (
     LoaderFunction, composed_weight_loader, sharded_weight_loader)
@@ -504,6 +504,7 @@ class MambaMixer2(MambaBase, CustomOp):
         seq_idx_p = attn_metadata.seq_idx_p
         chunk_indices_p = attn_metadata.chunk_indices_p
         chunk_offsets_p = attn_metadata.chunk_offsets_p
+        query_start_loc_p = attn_metadata.query_start_loc_p

         # 1. Gated MLP's linear projection
         projected_states, _ = self.in_proj(hidden_states)
@@ -545,6 +546,7 @@ class MambaMixer2(MambaBase, CustomOp):
             out, _ = self.out_proj(hidden_states)
             return out

         # NOTE: V0 put prefill before decode, v1 puts decode before prefill
         num_prefills = attn_metadata.num_prefills  # request count
         num_decodes = attn_metadata.num_decode_tokens  # token count (=request)
         num_prefill_tokens = attn_metadata.num_prefill_tokens  # token count
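
The NOTE above matters for the split in the next hunk: in v1 the flattened batch is ordered decode requests first, then prefill requests, so per-request tensors such as the cache-slot indices can be separated with a single torch.split. A toy illustration (values invented, not vLLM code):

    import torch

    # Toy values: 2 decode requests followed by 3 prefill requests,
    # one cache-slot index per request.
    num_decodes, num_prefills = 2, 3
    state_indices_tensor = torch.tensor([7, 1, 4, 0, 5])
    state_indices_tensor_d, state_indices_tensor_p = torch.split(
        state_indices_tensor, [num_decodes, num_prefills], dim=0)
    assert state_indices_tensor_d.tolist() == [7, 1]
    assert state_indices_tensor_p.tolist() == [4, 0, 5]
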
@@ -570,9 +572,6 @@ class MambaMixer2(MambaBase, CustomOp):
             [num_decodes, num_prefills],
             dim=0,
         )
-        query_start_loc_p = (
-            attn_metadata.query_start_loc[-num_prefills - 1:] -
-            num_decodes if has_prefill else None)

         # Preallocate output tensor to avoid memcpy cost for merging prefill
         # and decode outputs
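
The deleted lines derived the prefill-only query_start_loc locally: because decode tokens come first, the last num_prefills + 1 entries of the global query_start_loc belong to the prefill requests, and subtracting num_decodes rebases them to zero. After this refactor the metadata builder supplies the result directly as attn_metadata.query_start_loc_p (see the hunk at line 504). A worked example with invented lengths:

    import torch

    # Invented example: 2 decode requests (1 token each) followed by
    # prefill requests of lengths 3, 5, 2.
    num_decodes, num_prefills = 2, 3
    query_start_loc = torch.tensor([0, 1, 2, 5, 10, 12], dtype=torch.int32)

    # What the removed lines computed: slice off the prefill tail, rebase.
    query_start_loc_p = query_start_loc[-num_prefills - 1:] - num_decodes
    assert query_start_loc_p.tolist() == [0, 3, 8, 10]
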
@@ -620,15 +619,15 @@ class MambaMixer2(MambaBase, CustomOp):
                 ssm_state[state_indices_tensor_p], 0)

             # NOTE: final output is an in-place update of out tensor
-            varlen_state = mamba_chunk_scan_combined(
-                hidden_states_p.view(1, num_prefill_tokens,
+            varlen_states = mamba_chunk_scan_combined_varlen(
+                hidden_states_p.view(num_prefill_tokens,
                                      self.num_heads // self.tp_size,
                                      self.head_dim),
-                dt_p.unsqueeze(0),
+                dt_p,
                 self.A,
-                B_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size,
+                B_p.view(num_prefill_tokens, self.n_groups // self.tp_size,
                          -1),
-                C_p.view(1, num_prefill_tokens, self.n_groups // self.tp_size,
+                C_p.view(num_prefill_tokens, self.n_groups // self.tp_size,
                          -1),
                 chunk_size=chunk_size,
                 D=self.D,
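
The views above feed the kernel tokens-first, 3-D tensors; num_heads and n_groups are divided by tp_size because heads and B/C groups are sharded across tensor-parallel ranks. A shape sketch with invented sizes (vLLM's real dimensions come from the model config):

    import torch

    # Invented sizes: per-rank shapes after tensor-parallel sharding.
    num_prefill_tokens, tp_size = 10, 2
    num_heads, head_dim, n_groups, d_state = 8, 64, 2, 16

    hidden_states_p = torch.randn(num_prefill_tokens,
                                  num_heads * head_dim // tp_size)
    B_p = torch.randn(num_prefill_tokens, n_groups * d_state // tp_size)

    # Tokens-first 3-D views as in the call above; the old path prepended
    # a batch dim of 1 (and dt_p gained an unsqueeze(0)).
    x = hidden_states_p.view(num_prefill_tokens, num_heads // tp_size,
                             head_dim)
    B = B_p.view(num_prefill_tokens, n_groups // tp_size, -1)
    assert x.shape == (10, 4, 64)
    assert B.shape == (10, 1, 16)
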
@@ -639,17 +638,15 @@ class MambaMixer2(MambaBase, CustomOp):
                 chunk_offsets=chunk_offsets_p,
                 cu_seqlens=query_start_loc_p,
                 initial_states=initial_states,
-                return_varlen_states=True,
-                return_final_states=False,
                 dt_softplus=True,
                 dt_limit=(0.0, float("inf")),
-                out=preallocated_ssm_out_p.view(1, num_prefill_tokens, -1,
+                out=preallocated_ssm_out_p.view(num_prefill_tokens, -1,
                                                 self.head_dim),
                 state_dtype=ssm_state.dtype)

             # update ssm states
             # - varlen state is a (num_prefills, nheads, headdim, dstate) tensor
-            ssm_state[state_indices_tensor_p] = varlen_state
+            ssm_state[state_indices_tensor_p] = varlen_states

             # Process decode requests
             if has_decode:
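
Per the comment in the hunk above, varlen_states is a (num_prefills, nheads, headdim, dstate) tensor holding each prefill request's final SSM state; the assignment scatters those rows into the persistent cache at the slots chosen by state_indices_tensor_p. A small sketch with invented sizes:

    import torch

    # Invented sizes: a cache of 8 slots, 3 finished prefill requests.
    num_slots, num_prefills = 8, 3
    nheads, headdim, dstate = 4, 64, 16

    ssm_state = torch.zeros(num_slots, nheads, headdim, dstate)
    state_indices_tensor_p = torch.tensor([4, 0, 5])   # slot per request
    varlen_states = torch.randn(num_prefills, nheads, headdim, dstate)

    # Advanced indexing scatters one final state per selected cache row.
    ssm_state[state_indices_tensor_p] = varlen_states
    assert torch.equal(ssm_state[4], varlen_states[0])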