[Performance] Tune Mamba selective scan kernel for B200 (#32873)
Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
This commit is contained in:
@@ -41,6 +41,7 @@ from vllm.model_executor.model_loader.weight_utils import (
|
||||
sharded_weight_loader,
|
||||
)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import direct_register_custom_op
|
||||
from vllm.v1.attention.backend import AttentionMetadata
|
||||
from vllm.v1.attention.backends.mamba2_attn import Mamba2AttentionMetadata
|
||||
@@ -502,6 +503,9 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
dim=-1,
|
||||
)
|
||||
|
||||
# Check if running on a Blackwell device (compute-capability family 100, e.g. B200) for kernel tuning
|
||||
self.is_blackwell = current_platform.is_device_capability_family(100)
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -883,6 +887,7 @@ class MambaMixer2(MambaBase, CustomOp):
|
||||
state_batch_indices=state_indices_tensor_d_input,
|
||||
dst_state_batch_indices=state_indices_tensor_d_output,
|
||||
out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim),
|
||||
is_blackwell=self.is_blackwell,
|
||||
)
|
||||
|
||||
def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]:
|
||||
|
||||
Reference in New Issue
Block a user