feat(rocm-support): support mamba2 on rocm (#18565)

Signed-off-by: Islam Almersawi <islam.almersawi@openinnovation.ai>
Co-authored-by: Islam Almersawi <islam.almersawi@openinnovation.ai>
This commit is contained in:
almersawi
2025-05-27 11:07:53 +04:00
committed by GitHub
parent fc6d0c290f
commit a547aeb828
5 changed files with 60 additions and 49 deletions

View File

@@ -5,10 +5,9 @@ from dataclasses import dataclass
import torch
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.attention.backends.placeholder_attn import (
PlaceholderAttentionMetadata)
from vllm.attention.backends.xformers import XFormersMetadata
from vllm.platforms import current_platform
@dataclass
@@ -23,6 +22,21 @@ class Mamba2Metadata:
chunk_offsets: torch.Tensor
def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]:
    """Pick the attention-metadata classes that apply to this platform.

    The returned tuple is suitable as the second argument of
    ``isinstance``.

    Returns:
        On ROCm: ``(ROCmFlashAttentionMetadata,
        PlaceholderAttentionMetadata)``.
        On CUDA: ``(FlashAttentionMetadata, XFormersMetadata,
        PlaceholderAttentionMetadata)``.

    Raises:
        ValueError: If the current platform is neither ROCm nor CUDA.
    """
    # Backend modules are imported inside each branch rather than at module
    # level -- presumably so a backend that is unavailable on the running
    # platform is never imported (TODO confirm).
    if current_platform.is_rocm():
        from vllm.attention.backends.rocm_flash_attn import (
            ROCmFlashAttentionMetadata)

        rocm_classes = (ROCmFlashAttentionMetadata,
                        PlaceholderAttentionMetadata)
        return rocm_classes

    if current_platform.is_cuda():
        from vllm.attention.backends.flash_attn import FlashAttentionMetadata
        from vllm.attention.backends.xformers import XFormersMetadata

        cuda_classes = (FlashAttentionMetadata, XFormersMetadata,
                        PlaceholderAttentionMetadata)
        return cuda_classes

    # Neither supported platform matched: report the device type we saw.
    raise ValueError(
        f"Unsupported platform for Mamba2: {current_platform.device_type}")
def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
chunk_size: int,
total_seqlens: int):
@@ -78,9 +92,8 @@ def prepare_mamba2_metadata(
# Compute seq_idx, chunk_indices and chunk_offsets for prefill only
if num_prefills > 0:
if (isinstance(attn_metadata,
(FlashAttentionMetadata, XFormersMetadata,
PlaceholderAttentionMetadata))
attn_metadata_instances = get_platform_metadata_classes()
if (isinstance(attn_metadata, attn_metadata_instances)
and attn_metadata.context_lens_tensor is not None):
has_initial_states = \
attn_metadata.context_lens_tensor[:num_prefills] > 0 #[batch,]