[Performance] Tune Mamba selective scan kernel for B200 (#32873)
Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
@@ -41,6 +41,7 @@ from vllm.model_executor.model_loader.weight_utils import (
|
||||
sharded_weight_loader,
|
||||
)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
from vllm.v1.attention.backend import AttentionMetadata
|
||||
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
|
||||
@@ -502,6 +503,9 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
# Check if running on the Blackwell SM 10.x capability family (e.g. B200) for kernel tuning
|
||||
self.is_blackwell = current_platform.is_device_capability_family(100)
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -883,6 +887,7 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
state_batch_indices=state_indices_tensor_d_input,
|
||||
dst_state_batch_indices=state_indices_tensor_d_output,
|
||||
out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim),
|
||||
is_blackwell=self.is_blackwell,
|
||||
)
|
||||
|
||||
def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
|
||||
|
||||
@@ -286,6 +286,7 @@ def selective_state_update(
|
||||
out=None,
|
||||
num_accepted_tokens=None,
|
||||
cu_seqlens=None,
|
||||
is_blackwell=False,
|
||||
):
|
||||
"""
|
||||
Argument:
|
||||
@@ -391,17 +392,26 @@ def selective_state_update(
|
||||
if dst_state_batch_indices is not None
|
||||
else (0, 0)
|
||||
)
|
||||
# We don't want autotune since it will overwrite the state
|
||||
# We instead tune by hand.
|
||||
BLOCK_SIZE_M, num_warps = (
|
||||
(32, 4)
|
||||
if dstate <= 16
|
||||
else (
|
||||
(16, 4)
|
||||
if dstate <= 32
|
||||
else ((8, 4) if dstate <= 64 else ((4, 4) if dstate <= 128 else ((4, 8))))
|
||||
)
|
||||
)
|
||||
# We don't want autotune since it will overwrite the state.
|
||||
# We instead tune by hand based on dstate.
|
||||
|
||||
# Default; kept only when dstate > 128 on a non-Blackwell device
|
||||
BLOCK_SIZE_M, num_warps = 4, 8
|
||||
|
||||
if dstate <= 16:
|
||||
BLOCK_SIZE_M, num_warps = 32, 4
|
||||
elif dstate <= 32:
|
||||
BLOCK_SIZE_M, num_warps = 16, 4
|
||||
elif dstate <= 64:
|
||||
BLOCK_SIZE_M, num_warps = 8, 4
|
||||
else:
|
||||
# dstate > 64
|
||||
if is_blackwell:
|
||||
# Optimized for B200 with dstate>64
|
||||
BLOCK_SIZE_M, num_warps = 32, 8
|
||||
elif dstate <= 128:
|
||||
BLOCK_SIZE_M, num_warps = 4, 4
|
||||
|
||||
tie_hdim = (
|
||||
A.stride(-1) == 0
|
||||
and A.stride(-2) == 0
|
||||
|
||||
Reference in New Issue
Block a user