[Performance] Tune Mamba selective scan kernel for B200 (#32873)

Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
danisereb
2026-01-26 15:56:54 +02:00
committed by GitHub
parent 208c56256f
commit f4a0921c9c
2 changed files with 26 additions and 11 deletions

View File

@@ -41,6 +41,7 @@ from vllm.model_executor.model_loader.weight_utils import (
sharded_weight_loader,
)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backend import AttentionMetadata
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
@@ -502,6 +503,9 @@ class MambaMixer2(MambaBase, CustomOp):
dim=-1,
)
# Check if running on Blackwell (SM100+) for kernel tuning
self.is_blackwell = current_platform.is_device_capability_family(100)
def forward_native(
self,
hidden_states: torch.Tensor,
@@ -883,6 +887,7 @@ class MambaMixer2(MambaBase, CustomOp):
state_batch_indices=state_indices_tensor_d_input,
dst_state_batch_indices=state_indices_tensor_d_output,
out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim),
is_blackwell=self.is_blackwell,
)
def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]: