feat(rocm-support): support mamba2 on rocm (#18565)
Signed-off-by: Islam Almersawi <islam.almersawi@openinnovation.ai> Co-authored-by: Islam Almersawi <islam.almersawi@openinnovation.ai>
This commit is contained in:
@@ -5,10 +5,9 @@ from dataclasses import dataclass
|
||||
import torch
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
|
||||
from vllm.attention.backends.placeholder_attn import (
|
||||
PlaceholderAttentionMetadata)
|
||||
from vllm.attention.backends.xformers import XFormersMetadata
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -23,6 +22,21 @@ class Mamba2Metadata:
|
||||
chunk_offsets: torch.Tensor
|
||||
|
||||
|
||||
def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]:
|
||||
"""Returns the appropriate metadata classes for the current platform."""
|
||||
if current_platform.is_rocm():
|
||||
from vllm.attention.backends.rocm_flash_attn import (
|
||||
ROCmFlashAttentionMetadata)
|
||||
return (ROCmFlashAttentionMetadata, PlaceholderAttentionMetadata)
|
||||
elif current_platform.is_cuda():
|
||||
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
|
||||
from vllm.attention.backends.xformers import XFormersMetadata
|
||||
return (FlashAttentionMetadata, XFormersMetadata,
|
||||
PlaceholderAttentionMetadata)
|
||||
raise ValueError(
|
||||
f"Unsupported platform for Mamba2: {current_platform.device_type}")
|
||||
|
||||
|
||||
def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
|
||||
chunk_size: int,
|
||||
total_seqlens: int):
|
||||
@@ -78,9 +92,8 @@ def prepare_mamba2_metadata(
|
||||
|
||||
# Compute seq_idx, chunk_indices and chunk_offsets for prefill only
|
||||
if num_prefills > 0:
|
||||
if (isinstance(attn_metadata,
|
||||
(FlashAttentionMetadata, XFormersMetadata,
|
||||
PlaceholderAttentionMetadata))
|
||||
attn_metadata_instances = get_platform_metadata_classes()
|
||||
if (isinstance(attn_metadata, attn_metadata_instances)
|
||||
and attn_metadata.context_lens_tensor is not None):
|
||||
has_initial_states = \
|
||||
attn_metadata.context_lens_tensor[:num_prefills] > 0 #[batch,]
|
||||
|
||||
Reference in New Issue
Block a user