diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py
index 7af5e02c2..a620495b7 100644
--- a/vllm/model_executor/layers/mamba/mamba_mixer2.py
+++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py
@@ -41,6 +41,7 @@ from vllm.model_executor.model_loader.weight_utils import (
     sharded_weight_loader,
 )
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.platforms import current_platform
 from vllm.utils.torch_utils import direct_register_custom_op
 from vllm.v1.attention.backend import AttentionMetadata
 from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
@@ -502,6 +503,9 @@ class MambaMixer2(MambaBase, CustomOp):
             dim=-1,
         )
 
+        # Check if running on Blackwell (SM100+) for kernel tuning
+        self.is_blackwell = current_platform.is_device_capability_family(100)
+
     def forward_native(
         self,
         hidden_states: torch.Tensor,
@@ -883,6 +887,7 @@ class MambaMixer2(MambaBase, CustomOp):
             state_batch_indices=state_indices_tensor_d_input,
             dst_state_batch_indices=state_indices_tensor_d_output,
             out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim),
+            is_blackwell=self.is_blackwell,
         )
 
     def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py
index 628ad970c..a0df65f90 100644
--- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py
+++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py
@@ -286,6 +286,7 @@ def selective_state_update(
     out=None,
     num_accepted_tokens=None,
     cu_seqlens=None,
+    is_blackwell=False,
 ):
     """
     Argument:
@@ -391,17 +392,26 @@
         if dst_state_batch_indices is not None
         else (0, 0)
     )
-    # We don't want autotune since it will overwrite the state
-    # We instead tune by hand.
-    BLOCK_SIZE_M, num_warps = (
-        (32, 4)
-        if dstate <= 16
-        else (
-            (16, 4)
-            if dstate <= 32
-            else ((8, 4) if dstate <= 64 else ((4, 4) if dstate <= 128 else ((4, 8))))
-        )
-    )
+    # We don't want autotune since it will overwrite the state.
+    # We instead tune by hand based on dstate.
+
+    # Default
+    BLOCK_SIZE_M, num_warps = 4, 8
+
+    if dstate <= 16:
+        BLOCK_SIZE_M, num_warps = 32, 4
+    elif dstate <= 32:
+        BLOCK_SIZE_M, num_warps = 16, 4
+    elif dstate <= 64:
+        BLOCK_SIZE_M, num_warps = 8, 4
+    else:
+        # dstate > 64
+        if is_blackwell:
+            # Optimized for B200 with dstate>64
+            BLOCK_SIZE_M, num_warps = 32, 8
+        elif dstate <= 128:
+            BLOCK_SIZE_M, num_warps = 4, 4
+
     tie_hdim = (
         A.stride(-1) == 0
         and A.stride(-2) == 0