[V1] [Hybrid] Enable piecewise CUDA Graph for mamba layers (#21194)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Author: Thomas Parnell
Date: 2025-07-19 21:27:21 +02:00
Committed by: GitHub
Parent: 9f414a12ad
Commit: 881e3cbe3b
10 changed files with 100 additions and 31 deletions


@@ -13,7 +13,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce)
from vllm.forward_context import get_forward_context
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
@@ -33,6 +33,8 @@ from vllm.model_executor.model_loader.weight_utils import (
LoaderFunction, composed_weight_loader, sharded_weight_loader)
from vllm.model_executor.models.mamba_cache import MambaCacheParams
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionMetadata
# Added by the IBM Team, 2024
@@ -424,14 +426,36 @@ class MambaMixer2(MambaBase, CustomOp):
    def forward_native(
        self,
        hidden_states: torch.Tensor,
        conv_state: torch.Tensor,
        ssm_state: torch.Tensor,
        output: torch.Tensor,
        mamba_cache_params: MambaCacheParams,
        mamba2_metadata: Mamba2Metadata,
        mup_vector: Optional[torch.Tensor] = None,
    ):
        pass

    def forward(
        self,
        hidden_states: torch.Tensor,
        output: torch.Tensor,
        mamba_cache_params: MambaCacheParams,
        mamba2_metadata: Mamba2Metadata,
        mup_vector: Optional[torch.Tensor] = None,
    ):
        if not envs.VLLM_USE_V1:
            CustomOp.forward(self, hidden_states, output, mamba_cache_params,
                             mamba2_metadata, mup_vector)
        else:
            torch.ops.vllm.mamba_mixer2(
                hidden_states,
                output,
                self.prefix,
                mup_vector,
            )

    def forward_cuda(
        self,
        hidden_states: torch.Tensor,
        output: torch.Tensor,
        mamba_cache_params: MambaCacheParams,
        mamba2_metadata: Mamba2Metadata,
        mup_vector: Optional[torch.Tensor] = None,
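In the V1 path, forward() above no longer runs the mixer inline: it calls the registered op torch.ops.vllm.mamba_mixer2 with only tensors plus the layer's prefix string, and the op body (added at the bottom of this diff) looks the layer back up and writes the result into the caller-provided output buffer. Below is a minimal, self-contained sketch of that dispatch-by-layer-name pattern; the dict registry and the ToyMixer class are made-up stand-ins for vLLM's forward context and MambaMixer2.

# Sketch only: a plain dict stands in for forward_context.no_compile_layers,
# and ToyMixer stands in for MambaMixer2. All names here are hypothetical.
import torch

_LAYERS: dict[str, "ToyMixer"] = {}


class ToyMixer(torch.nn.Module):

    def __init__(self, prefix: str, factor: float):
        super().__init__()
        self.prefix = prefix
        self.factor = factor
        _LAYERS[prefix] = self  # register the layer under its prefix

    def forward(self, hidden_states: torch.Tensor,
                output: torch.Tensor) -> None:
        # The real layer routes this through torch.ops.vllm.mamba_mixer2 so
        # torch.compile sees one opaque call; the sketch calls the op body
        # directly to stay short.
        toy_mixer_op(hidden_states, output, self.prefix)

    def forward_cuda(self, hidden_states: torch.Tensor,
                     output: torch.Tensor) -> None:
        # Write the result in place instead of returning a new tensor.
        output[:hidden_states.shape[0]] = hidden_states * self.factor


def toy_mixer_op(hidden_states: torch.Tensor, output: torch.Tensor,
                 layer_name: str) -> None:
    # Recover the layer object from its name, mirroring
    # forward_context.no_compile_layers[layer_name] in mamba_mixer2() below.
    _LAYERS[layer_name].forward_cuda(hidden_states, output)


mixer = ToyMixer("model.layers.0.mixer", factor=2.0)
buf = torch.zeros(4)
mixer(torch.ones(3), buf)
print(buf)  # tensor([2., 2., 2., 0.]) -- last row is untouched padding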
@@ -517,6 +541,7 @@ class MambaMixer2(MambaBase, CustomOp):
        num_prefill_tokens = attn_metadata.num_prefill_tokens  # token count
        has_prefill = num_prefills > 0
        has_decode = num_decodes > 0
        num_actual_tokens = num_prefill_tokens + num_decodes

        # NOTE: V0 put prefill before decode, v1 puts decode before prefill
        # Separate prefill and decode by splitting varlen input
@@ -524,18 +549,18 @@ class MambaMixer2(MambaBase, CustomOp):
        # NOTE: V0 put prefill before decode, v1 puts decode before prefill
        if envs.VLLM_USE_V1:
            hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
                hidden_states_B_C,
                hidden_states_B_C[:num_actual_tokens],
                [num_decodes, num_prefill_tokens],
                dim=0,
            )
            dt_d, dt_p = torch.split(
                dt,
                dt[:num_actual_tokens],
                [num_decodes, num_prefill_tokens],
                dim=0,
            )
            # Split along batch dimension
            state_indices_tensor_d, state_indices_tensor_p = torch.split(
                state_indices_tensor,
                state_indices_tensor[:num_actual_tokens],
                [num_decodes, num_prefills],
                dim=0,
            )
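The new num_actual_tokens bound reflects that the input buffers can be padded (for example, to a fixed batch size for CUDA graph capture), so only the first num_actual_tokens rows hold real data; slicing before torch.split keeps the padding out of both the decode and prefill halves. A small, runnable illustration with made-up sizes:

# Illustration only: the sizes are made up; in the layer they come from the
# attention metadata (num_decodes, num_prefill_tokens) and the padded batch.
import torch

num_decodes = 3          # decode requests, one token each (placed first in V1)
num_prefill_tokens = 5   # tokens across all prefill requests
num_actual_tokens = num_decodes + num_prefill_tokens
padded_len = 16          # buffer length padded for graph capture

hidden = torch.randn(padded_len, 8)

decode_part, prefill_part = torch.split(
    hidden[:num_actual_tokens],          # drop the padded tail first
    [num_decodes, num_prefill_tokens],   # decode tokens come before prefill
    dim=0,
)
assert decode_part.shape == (num_decodes, 8)
assert prefill_part.shape == (num_prefill_tokens, 8)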
@@ -696,11 +721,10 @@ class MambaMixer2(MambaBase, CustomOp):
        # GatedRMSNorm internally applying SiLU to the gate
        # SiLU is applied internally before normalization, unlike standard
        # norm usage
        hidden_states = self.norm(hidden_states, gate)
        hidden_states = self.norm(hidden_states, gate[:num_actual_tokens])

        # 5. Final linear projection
        out, _ = self.out_proj(hidden_states)
        return out
        output[:num_actual_tokens], _ = self.out_proj(hidden_states)

    def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
        return get_mamba_state_shape(
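Writing the projection into output[:num_actual_tokens] instead of returning a fresh tensor is what makes the op fit the CUDA graph model: the output buffer is allocated once by the caller, its address stays fixed across graph replays, and the padded rows past num_actual_tokens are simply left untouched. A toy version of the same pattern, with made-up shapes and a plain Linear standing in for out_proj:

# Toy version of the in-place output write; hidden_size/padded_len are made up.
import torch

hidden_size, padded_len, num_actual_tokens = 8, 16, 6

out_proj = torch.nn.Linear(hidden_size, hidden_size, bias=False)
output = torch.zeros(padded_len, hidden_size)      # preallocated, reused buffer
hidden_states = torch.randn(num_actual_tokens, hidden_size)

# Write only the rows that carry real tokens; padding rows stay untouched.
output[:num_actual_tokens] = out_proj(hidden_states)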
@@ -712,3 +736,36 @@ class MambaMixer2(MambaBase, CustomOp):
            state_size=self.ssm_state_size,
            conv_kernel=self.conv_kernel_size,
        )


def mamba_mixer2(
    hidden_states: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    mup_vector: Optional[torch.Tensor] = None,
) -> None:
    forward_context: ForwardContext = get_forward_context()
    self = forward_context.no_compile_layers[layer_name]
    self.forward_cuda(hidden_states=hidden_states,
                      output=output,
                      mamba_cache_params=None,
                      mamba2_metadata=None,
                      mup_vector=mup_vector)


def mamba_mixer2_fake(
    hidden_states: torch.Tensor,
    output: torch.Tensor,
    layer_name: str,
    mup_vector: Optional[torch.Tensor] = None,
) -> None:
    return


direct_register_custom_op(
    op_name="mamba_mixer2",
    op_func=mamba_mixer2,
    mutates_args=["output"],
    fake_impl=mamba_mixer2_fake,
    dispatch_key=current_platform.dispatch_key,
)
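This registration block is what lets the compiler treat the whole mixer as a single opaque op that mutates only output, with mamba_mixer2_fake standing in during tracing; piecewise CUDA graph compilation (per the commit title) can then split the traced graph at such ops and capture graphs for the pieces in between. As a rough analogue using only the public torch.library API (PyTorch 2.4+), with a made-up op name and function:

# Rough plain-PyTorch sketch of registering an output-mutating custom op with
# a fake impl; "toylib::add_into" and step() are hypothetical.
import torch


@torch.library.custom_op("toylib::add_into", mutates_args=("out",))
def add_into(x: torch.Tensor, out: torch.Tensor) -> None:
    out.add_(x)            # real kernel: mutate the output buffer in place


@add_into.register_fake
def _(x: torch.Tensor, out: torch.Tensor) -> None:
    return None            # fake impl: nothing to return, op only mutates `out`


def step(x: torch.Tensor, out: torch.Tensor) -> torch.Tensor:
    y = x * 2.0                         # traceable / graph-capturable region
    torch.ops.toylib.add_into(y, out)   # opaque custom-op node
    return out + 1.0                    # traceable / graph-capturable region


compiled = torch.compile(step)
x = torch.ones(4)
out = torch.zeros(4)
print(compiled(x, out))  # tensor([3., 3., 3., 3.])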