[Performance] Tune Mamba selective scan kernel for B200 (#32873)

Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
danisereb
2026-01-26 15:56:54 +02:00
committed by GitHub
parent 208c56256f
commit f4a0921c9c
2 changed files with 26 additions and 11 deletions

View File

@@ -41,6 +41,7 @@ from vllm.model_executor.model_loader.weight_utils import (
sharded_weight_loader,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
@@ -502,6 +503,9 @@ class MambaMixer2(MambaBase, CustomOp):
dim=-1,
)
# Check if running on Blackwell (SM100+) for kernel tuning
self.is_blackwell = current_platform.is_device_capability_family(100)
def forward_native(
self,
hidden_states: torch.Tensor,
@@ -883,6 +887,7 @@ class MambaMixer2(MambaBase, CustomOp):
state_batch_indices=state_indices_tensor_d_input,
dst_state_batch_indices=state_indices_tensor_d_output,
out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim),
is_blackwell=self.is_blackwell,
)
def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:

View File

@@ -286,6 +286,7 @@ def selective_state_update(
out=None,
num_accepted_tokens=None,
cu_seqlens=None,
is_blackwell=False,
):
"""
Argument:
@@ -391,17 +392,26 @@ def selective_state_update(
if dst_state_batch_indices is not None
else (0, 0)
)
# We don't want autotune since it will overwrite the state
# We instead tune by hand.
BLOCK_SIZE_M, num_warps = (
(32, 4)
if dstate <= 16
else (
(16, 4)
if dstate <= 32
else ((8, 4) if dstate <= 64 else ((4, 4) if dstate <= 128 else ((4, 8))))
)
)
# We don't want autotune since it will overwrite the state.
# We instead tune by hand based on dstate.
# Default
BLOCK_SIZE_M, num_warps = 4, 8
if dstate <= 16:
BLOCK_SIZE_M, num_warps = 32, 4
elif dstate <= 32:
BLOCK_SIZE_M, num_warps = 16, 4
elif dstate <= 64:
BLOCK_SIZE_M, num_warps = 8, 4
else:
# dstate > 64
if is_blackwell:
# Optimized for B200 with dstate>64
BLOCK_SIZE_M, num_warps = 32, 8
elif dstate <= 128:
BLOCK_SIZE_M, num_warps = 4, 4
tie_hdim = (
A.stride(-1) == 0
and A.stride(-2) == 0