[V1][Mamba1] - Full CUDA and Piecewise CUDA Graphs Support (#23035)

Signed-off-by: asafg <asafg@ai21.com>
Signed-off-by: asafg <39553475+Josephasafg@users.noreply.github.com>
Co-authored-by: asafg <asafg@ai21.com>
Author: Asaf Joseph Gardin
Date: 2025-08-21 06:08:51 +03:00
Committed by: GitHub
Commit: 3663870c72 (parent: 2461d9e562)
9 changed files with 154 additions and 87 deletions


@@ -27,6 +27,8 @@ from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
     selective_scan_fn, selective_state_update)
 from vllm.model_executor.models.mamba_cache import MambaCacheParams
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.platforms import current_platform
+from vllm.utils import direct_register_custom_op
 from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata
@@ -183,22 +185,26 @@ class MambaMixer(MambaBase, CustomOp):
     def forward(self,
                 hidden_states: torch.Tensor,
+                output: torch.Tensor,
                 mamba_cache_params: Optional[MambaCacheParams] = None):
         if not envs.VLLM_USE_V1:
-            return CustomOp.forward(self, hidden_states, mamba_cache_params)
+            CustomOp.forward(self, hidden_states, output, mamba_cache_params)
         else:
-            return self.forward_cuda(
+            torch.ops.vllm.mamba_mixer(
                 hidden_states,
-                mamba_cache_params,
+                output,
+                self.prefix,
             )

     def forward_native(self,
                        hidden_states: torch.Tensor,
+                       output: torch.Tensor,
                        mamba_cache_params: Optional[MambaCacheParams] = None):
         pass

     def forward_cuda(self,
                      hidden_states: torch.Tensor,
+                     output: torch.Tensor,
                      mamba_cache_params: Optional[MambaCacheParams] = None):
         """
         Run the Mamba-1 SSM pipeline.
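On the V1 path, `forward` no longer calls `forward_cuda` directly: it dispatches through the `torch.ops.vllm.mamba_mixer` custom op registered at the bottom of this file, so `torch.compile` and piecewise CUDA graph capture see the whole mixer as one opaque op that mutates `output` in place. A minimal standalone sketch of the same pattern, using the public `torch.library.custom_op` API with invented names (`demo::my_mixer`, `MyLayer`, `_LAYERS`) rather than vLLM's helpers:

import torch

# Stand-in for vLLM's forward_context.no_compile_layers registry.
_LAYERS: dict[str, "MyLayer"] = {}

@torch.library.custom_op("demo::my_mixer", mutates_args=("output",))
def my_mixer(hidden_states: torch.Tensor, output: torch.Tensor,
             layer_name: str) -> None:
    # Resolve the layer by name at call time; results are written into
    # `output`, so nothing is returned through the traced graph.
    _LAYERS[layer_name].forward_cuda(hidden_states, output)

@my_mixer.register_fake
def _(hidden_states: torch.Tensor, output: torch.Tensor,
      layer_name: str) -> None:
    # Tracing only needs to know the op mutates `output`; no compute here.
    return None

class MyLayer(torch.nn.Module):
    def __init__(self, name: str):
        super().__init__()
        self.name = name
        _LAYERS[name] = self

    def forward(self, hidden_states: torch.Tensor,
                output: torch.Tensor) -> None:
        # Route through the registered op so compiled graphs see one node.
        torch.ops.demo.my_mixer(hidden_states, output, self.name)

    def forward_cuda(self, hidden_states: torch.Tensor,
                     output: torch.Tensor) -> None:
        output.copy_(hidden_states * 2)  # placeholder for the real SSM kernel

Passing `self.prefix` (the layer name) instead of the module itself is what keeps the op signature graph-friendly: only tensors and a string cross the op boundary.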
@@ -237,6 +243,7 @@ class MambaMixer(MambaBase, CustomOp):
             conv_state = self_kv_cache[0].transpose(-1, -2)
             ssm_state = self_kv_cache[1]
             has_initial_states = mamba1_metadata.has_initial_states
+            num_padded_decodes = mamba1_metadata.num_padded_decodes
         else:
             assert isinstance(attn_metadata, AttentionMetadata)
             assert mamba_cache_params is not None
@@ -248,6 +255,7 @@ class MambaMixer(MambaBase, CustomOp):
             has_initial_states = None
             if context_lens_tensor is not None:
                 has_initial_states = context_lens_tensor > 0
+            num_padded_decodes = attn_metadata.num_decode_tokens

         # 1. Gated MLP's linear projection
         projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
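`num_padded_decodes` is the decode-token count as the kernels see it: under CUDA graph capture the decode batch is padded up to the nearest captured graph size, so the V1 metadata reports the padded count, while the V0 path (the `else` branch above) has no graph padding and simply reuses `attn_metadata.num_decode_tokens`. A rough sketch of the bookkeeping, with invented sizes (`cudagraph_batch_sizes` stands in for whatever sizes the runner captured):

import torch

# Invented example sizes; illustrates the bookkeeping only.
cudagraph_batch_sizes = [1, 2, 4, 8]   # sizes the runner captured graphs for
num_decode_tokens = 3                  # real single-token decode requests
num_prefill_tokens = 5                 # prefill tokens follow the decodes (v1)

# Pad decodes up to the nearest captured size so a fixed-shape graph can
# be replayed; prefill tokens are never padded this way.
num_padded_decodes = min(s for s in cudagraph_batch_sizes
                         if s >= num_decode_tokens)          # -> 4
num_actual_tokens = num_padded_decodes + num_prefill_tokens  # -> 9

hidden_states_BC = torch.randn(32, 16)  # channels x tokens, incl. padding
decode_part, prefill_part = torch.split(
    hidden_states_BC[..., :num_actual_tokens],
    [num_padded_decodes, num_prefill_tokens], dim=-1)
assert decode_part.shape[-1] == 4 and prefill_part.shape[-1] == 5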
@@ -267,6 +275,7 @@ class MambaMixer(MambaBase, CustomOp):
         num_decodes = attn_metadata.num_decode_tokens  # token count (=request)
         has_prefill = num_prefill_tokens > 0
         has_decode = num_decode_tokens > 0
+        num_actual_tokens = num_prefill_tokens + num_decode_tokens

         prefill_decode_split = split_batch_to_prefill_and_decode(
             hidden_states_BC,
@@ -278,6 +287,7 @@ class MambaMixer(MambaBase, CustomOp):
             num_decode_tokens,
             num_prefills,
             num_decodes,
+            num_padded_decodes,
         )
         hidden_states_BC_p = prefill_decode_split.hidden_states_BC_p
         hidden_states_BC_d = prefill_decode_split.hidden_states_BC_d
@@ -371,7 +381,7 @@ class MambaMixer(MambaBase, CustomOp):
         else:
             out = self.out_proj(scan_outputs_combined.transpose(-2, -1))[0]
-        return out
+        output[:num_actual_tokens] = out

     def get_state_dtype(self) -> tuple[torch.dtype]:
         assert self.model_config is not None
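The return value is gone: `forward_cuda` now writes its result into the caller-provided `output` buffer, which is what makes the op CUDA-graph-safe, since graph replay requires every tensor in the captured region to keep a fixed address. Only the first `num_actual_tokens` rows are valid; trailing padded rows are left untouched. A small sketch of that contract (names and shapes are illustrative):

import torch

max_num_tokens, hidden_size = 8, 16
# Allocated once by the model runner and reused across graph replays,
# so its address never changes.
output = torch.empty(max_num_tokens, hidden_size)

def mixer_step(hidden_states: torch.Tensor, output: torch.Tensor,
               num_actual_tokens: int) -> None:
    out = hidden_states * 2.0  # stand-in for conv + SSM scan + out_proj
    # Fill the valid prefix in place instead of returning a new tensor;
    # rows past num_actual_tokens are padding and ignored by the caller.
    output[:num_actual_tokens] = out[:num_actual_tokens]

hidden_states = torch.randn(max_num_tokens, hidden_size)
mixer_step(hidden_states, output, num_actual_tokens=5)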
@@ -421,18 +431,27 @@ def split_batch_to_prefill_and_decode(
     num_decode_tokens: int,
     num_prefills: int,
     num_decodes: int,
+    num_padded_decodes: int,
 ) -> PrefillDecodeSplit:
+    num_actual_tokens = num_prefill_tokens + num_padded_decodes
     if envs.VLLM_USE_V1:
         # In v1, decode tokens come first, then prefill tokens.
         hidden_states_BC_d, hidden_states_BC_p = torch.split(
-            hidden_states_BC, [num_decode_tokens, num_prefill_tokens], dim=-1)
-        gate_d, gate_p = torch.split(gate,
-                                     [num_decode_tokens, num_prefill_tokens],
+            hidden_states_BC[..., :num_actual_tokens],
+            [num_padded_decodes, num_prefill_tokens],
+            dim=-1)
+        gate_d, gate_p = torch.split(gate[..., :num_actual_tokens],
+                                     [num_padded_decodes, num_prefill_tokens],
                                      dim=-1)
+        # num_padded_decodes accounts for CUDA graph padding when applicable
         state_indices_tensor_d, state_indices_tensor_p = torch.split(
-            state_indices_tensor, [num_decodes, num_prefills], dim=0)
+            state_indices_tensor[:num_padded_decodes + num_prefills],
+            [num_padded_decodes, num_prefills],
+            dim=0)
         query_start_loc_p = (query_start_loc[-num_prefills - 1:] -
-                             num_decodes if num_prefills > 0 else None)
+                             num_padded_decodes if num_prefills > 0 else None)
         has_initial_states_p = has_initial_states[-num_prefills:] if (
             has_initial_states is not None and num_prefills > 0) else None
     else:
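With the padded count threaded through, the splitter slices every per-token tensor down to `num_actual_tokens` before splitting, so graph-padding columns beyond the real batch never reach the kernels, and the prefill `query_start_loc` is rebased by `num_padded_decodes` rather than `num_decodes`. A worked toy example of the v1 decode-first layout (all sizes invented):

import torch

# Layout in v1: [decode tokens (padded) | prefill tokens | extra padding].
num_padded_decodes = 4      # 3 real decodes padded to a captured size of 4
num_prefills = 2            # two prefill requests of 3 tokens each
num_prefill_tokens = 6
num_actual_tokens = num_padded_decodes + num_prefill_tokens  # = 10

gate = torch.randn(32, 12)  # channels x tokens; last 2 columns are padding
gate_d, gate_p = torch.split(gate[..., :num_actual_tokens],
                             [num_padded_decodes, num_prefill_tokens],
                             dim=-1)
assert gate_d.shape[-1] == 4 and gate_p.shape[-1] == 6

# Cumulative boundaries: 4 single-token decodes, then prefills at 4, 7, 10.
query_start_loc = torch.tensor([0, 1, 2, 3, 4, 7, 10])
# Rebase the prefill boundaries so they start at zero for the prefill kernel.
query_start_loc_p = query_start_loc[-num_prefills - 1:] - num_padded_decodes
assert query_start_loc_p.tolist() == [0, 3, 6]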
@@ -459,3 +478,32 @@ def split_batch_to_prefill_and_decode(
         query_start_loc_p=query_start_loc_p,
         has_initial_states_p=has_initial_states_p,
     )
+
+
+def mamba_mixer(
+    hidden_states: torch.Tensor,
+    output: torch.Tensor,
+    layer_name: str,
+) -> None:
+    forward_context: ForwardContext = get_forward_context()
+    self = forward_context.no_compile_layers[layer_name]
+    self.forward_cuda(hidden_states=hidden_states,
+                      output=output,
+                      mamba_cache_params=None)
+
+
+def mamba_mixer_fake(
+    hidden_states: torch.Tensor,
+    output: torch.Tensor,
+    layer_name: str,
+) -> None:
+    return
+
+
+direct_register_custom_op(
+    op_name="mamba_mixer",
+    op_func=mamba_mixer,
+    mutates_args=["output"],
+    fake_impl=mamba_mixer_fake,
+    dispatch_key=current_platform.dispatch_key,
+)
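`direct_register_custom_op` is vLLM's thin wrapper around `torch.library`: `mutates_args=["output"]` tells the compiler the op writes into `output` so the in-place store is kept alive, and the fake impl can be a pure no-op because the op returns nothing and tracing only needs the mutation/aliasing info. A hedged sketch of roughly what such a registration does with the plain `torch.library` API, using an invented `demo::scale_into` op:

import torch
from torch.library import Library

lib = Library("demo", "FRAGMENT")  # keep a module-level ref so it isn't GC'd
# "(a!)" marks `output` as mutated in place; the op returns nothing.
lib.define("scale_into(Tensor x, Tensor(a!) output) -> ()")

def scale_into_impl(x: torch.Tensor, output: torch.Tensor) -> None:
    output.copy_(x * 2.0)  # the real kernel would run here

def scale_into_fake(x: torch.Tensor, output: torch.Tensor) -> None:
    # Nothing to compute for tracing: there is no return value, and the
    # mutation is already declared in the schema above.
    pass

lib.impl("scale_into", scale_into_impl, "CPU")  # dispatch key per platform
torch.library.register_fake("demo::scale_into", scale_into_fake)

x = torch.randn(4)
out = torch.empty(4)
torch.ops.demo.scale_into(x, out)
assert torch.allclose(out, x * 2.0)

Registering per platform via `dispatch_key=current_platform.dispatch_key` is what lets the same `torch.ops.vllm.mamba_mixer` symbol resolve to the right backend at runtime.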