[V1][Mamba1] - Full CUDA and Piecewise CUDA Graphs Support (#23035)
Signed-off-by: asafg <asafg@ai21.com>
Signed-off-by: asafg <39553475+Josephasafg@users.noreply.github.com>
Co-authored-by: asafg <asafg@ai21.com>
Committed by GitHub
Parent: 2461d9e562
Commit: 3663870c72
@@ -27,6 +27,8 @@ from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
     selective_scan_fn, selective_state_update)
 from vllm.model_executor.models.mamba_cache import MambaCacheParams
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.platforms import current_platform
+from vllm.utils import direct_register_custom_op
 from vllm.v1.attention.backends.mamba1_attn import Mamba1AttentionMetadata
@@ -183,22 +185,26 @@ class MambaMixer(MambaBase, CustomOp):

     def forward(self,
                 hidden_states: torch.Tensor,
+                output: torch.Tensor,
                 mamba_cache_params: Optional[MambaCacheParams] = None):
         if not envs.VLLM_USE_V1:
-            return CustomOp.forward(self, hidden_states, mamba_cache_params)
+            CustomOp.forward(self, hidden_states, output, mamba_cache_params)
         else:
-            return self.forward_cuda(
+            torch.ops.vllm.mamba_mixer(
                 hidden_states,
-                mamba_cache_params,
+                output,
+                self.prefix,
             )

     def forward_native(self,
                        hidden_states: torch.Tensor,
+                       output: torch.Tensor,
                        mamba_cache_params: Optional[MambaCacheParams] = None):
         pass

     def forward_cuda(self,
                      hidden_states: torch.Tensor,
+                     output: torch.Tensor,
                      mamba_cache_params: Optional[MambaCacheParams] = None):
         """
         Run the Mamba-1 SSM pipeline.
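Note: the V1 path now routes through torch.ops.vllm.mamba_mixer (registered at the bottom of this diff via direct_register_custom_op) and writes into a caller-provided buffer instead of returning a tensor. That makes the mixer an opaque op with a stable output address, which is what piecewise CUDA graph capture needs. A minimal sketch of the same in-place custom-op pattern in plain PyTorch (torch.library.custom_op, PyTorch 2.4+; all names here are illustrative, not vLLM's):

import torch

@torch.library.custom_op("demo::inplace_mixer", mutates_args=("output",))
def inplace_mixer(hidden_states: torch.Tensor, output: torch.Tensor) -> None:
    # Stand-in for the real SSM computation; the result lands in `output`.
    output.copy_(torch.nn.functional.silu(hidden_states))

@inplace_mixer.register_fake
def _(hidden_states: torch.Tensor, output: torch.Tensor) -> None:
    # Fake impl for tracing: nothing to return, the mutated arg carries shape.
    return

x = torch.randn(4, 8)
out = torch.empty_like(x)
torch.ops.demo.inplace_mixer(x, out)  # result is written into `out`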
@@ -237,6 +243,7 @@ class MambaMixer(MambaBase, CustomOp):
             conv_state = self_kv_cache[0].transpose(-1, -2)
             ssm_state = self_kv_cache[1]
             has_initial_states = mamba1_metadata.has_initial_states
+            num_padded_decodes = mamba1_metadata.num_padded_decodes
         else:
             assert isinstance(attn_metadata, AttentionMetadata)
             assert mamba_cache_params is not None
@@ -248,6 +255,7 @@ class MambaMixer(MambaBase, CustomOp):
             has_initial_states = None
             if context_lens_tensor is not None:
                 has_initial_states = context_lens_tensor > 0
+            num_padded_decodes = attn_metadata.num_decode_tokens

         # 1. Gated MLP's linear projection
         projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
@@ -267,6 +275,7 @@ class MambaMixer(MambaBase, CustomOp):
         num_decodes = attn_metadata.num_decode_tokens  # token count (=request)
         has_prefill = num_prefill_tokens > 0
         has_decode = num_decode_tokens > 0
+        num_actual_tokens = num_prefill_tokens + num_decode_tokens

         prefill_decode_split = split_batch_to_prefill_and_decode(
             hidden_states_BC,
@@ -278,6 +287,7 @@ class MambaMixer(MambaBase, CustomOp):
             num_decode_tokens,
             num_prefills,
             num_decodes,
+            num_padded_decodes,
         )
         hidden_states_BC_p = prefill_decode_split.hidden_states_BC_p
         hidden_states_BC_d = prefill_decode_split.hidden_states_BC_d
@@ -371,7 +381,7 @@ class MambaMixer(MambaBase, CustomOp):
         else:
             out = self.out_proj(scan_outputs_combined.transpose(-2, -1))[0]

-        return out
+        output[:num_actual_tokens] = out

     def get_state_dtype(self) -> tuple[torch.dtype]:
         assert self.model_config is not None
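Note: instead of returning a fresh tensor, forward_cuda now writes into the preallocated buffer, and only the first num_actual_tokens rows are live; any rows added by CUDA graph padding stay untouched. A toy illustration of the slice assignment (sizes made up):

import torch

output = torch.zeros(8, 4)        # buffer padded to the captured batch size
out = torch.ones(5, 4)            # results for the 5 real tokens
num_actual_tokens = out.shape[0]
output[:num_actual_tokens] = out  # rows 5..7 remain padding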
@@ -421,18 +431,27 @@ def split_batch_to_prefill_and_decode(
     num_decode_tokens: int,
     num_prefills: int,
     num_decodes: int,
+    num_padded_decodes: int,
 ) -> PrefillDecodeSplit:
+    num_actual_tokens = num_prefill_tokens + num_padded_decodes
+
     if envs.VLLM_USE_V1:
         # In v1, decode tokens come first, then prefill tokens.
         hidden_states_BC_d, hidden_states_BC_p = torch.split(
-            hidden_states_BC, [num_decode_tokens, num_prefill_tokens], dim=-1)
-        gate_d, gate_p = torch.split(gate,
-                                     [num_decode_tokens, num_prefill_tokens],
-                                     dim=-1)
+            hidden_states_BC[..., :num_actual_tokens],
+            [num_padded_decodes, num_prefill_tokens],
+            dim=-1)
+        gate_d, gate_p = torch.split(gate[..., :num_actual_tokens],
+                                     [num_padded_decodes, num_prefill_tokens],
+                                     dim=-1)
+
+        # num_padded_decodes accounts for CUDA graph padding when applicable
         state_indices_tensor_d, state_indices_tensor_p = torch.split(
-            state_indices_tensor, [num_decodes, num_prefills], dim=0)
+            state_indices_tensor[:num_padded_decodes + num_prefills],
+            [num_padded_decodes, num_prefills],
+            dim=0)
         query_start_loc_p = (query_start_loc[-num_prefills - 1:] -
-                             num_decodes if num_prefills > 0 else None)
+                             num_padded_decodes if num_prefills > 0 else None)
         has_initial_states_p = has_initial_states[-num_prefills:] if (
             has_initial_states is not None and num_prefills > 0) else None
     else:
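Note: the splits are now sized by num_padded_decodes rather than the raw decode count, and the inputs are truncated to num_actual_tokens first, so graph-padding slots beyond the live tokens never enter the split. A toy illustration of the decode-first layout (numbers made up):

import torch

num_prefill_tokens = 6
num_padded_decodes = 4                           # includes CUDA graph padding
hidden = torch.arange(12).unsqueeze(0)           # 12 slots, 2 are pure padding
num_actual_tokens = num_prefill_tokens + num_padded_decodes

decode_part, prefill_part = torch.split(
    hidden[..., :num_actual_tokens],             # drop trailing padding slots
    [num_padded_decodes, num_prefill_tokens],
    dim=-1)
print(decode_part.shape, prefill_part.shape)     # (1, 4) (1, 6)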
@@ -459,3 +478,32 @@ def split_batch_to_prefill_and_decode(
         query_start_loc_p=query_start_loc_p,
         has_initial_states_p=has_initial_states_p,
     )
+
+
+def mamba_mixer(
+    hidden_states: torch.Tensor,
+    output: torch.Tensor,
+    layer_name: str,
+) -> None:
+    forward_context: ForwardContext = get_forward_context()
+    self = forward_context.no_compile_layers[layer_name]
+    self.forward_cuda(hidden_states=hidden_states,
+                      output=output,
+                      mamba_cache_params=None)
+
+
+def mamba_mixer_fake(
+    hidden_states: torch.Tensor,
+    output: torch.Tensor,
+    layer_name: str,
+) -> None:
+    return
+
+
+direct_register_custom_op(
+    op_name="mamba_mixer",
+    op_func=mamba_mixer,
+    mutates_args=["output"],
+    fake_impl=mamba_mixer_fake,
+    dispatch_key=current_platform.dispatch_key,
+)
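Note: the op takes layer_name (the module's prefix) rather than the module itself, because custom ops only accept tensors and plain values; the body recovers the layer from forward_context.no_compile_layers at call time, and the fake impl can be a no-op since the mutated output already carries the shape information tracing needs. A rough sketch of the string-keyed lookup pattern (illustrative only, not vLLM's internals):

import torch
import torch.nn as nn

LAYER_REGISTRY: dict[str, nn.Module] = {}

class DemoMixer(nn.Module):
    def __init__(self, prefix: str):
        super().__init__()
        self.prefix = prefix
        LAYER_REGISTRY[prefix] = self            # register under string key

    def forward_cuda(self, hidden_states: torch.Tensor,
                     output: torch.Tensor) -> None:
        output.copy_(hidden_states * 2)

def demo_mixer_op(hidden_states: torch.Tensor, output: torch.Tensor,
                  layer_name: str) -> None:
    # Resolve the module from the registry, mirroring
    # forward_context.no_compile_layers[layer_name] above.
    LAYER_REGISTRY[layer_name].forward_cuda(hidden_states, output)

layer = DemoMixer("model.layers.0.mixer")
x = torch.randn(3, 4)
out = torch.empty_like(x)
demo_mixer_op(x, out, "model.layers.0.mixer")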