[ROCm] Utilize persistent MLA kernel from AITER (#36574)

Signed-off-by: Sathish Sanjeevi <sathish.krishnan.p.s@gmail.com>
This commit is contained in:
Sathish Sanjeevi
2026-03-25 12:00:42 -07:00
committed by GitHub
parent 7d6917bef5
commit 978fc18bf0
2 changed files with 154 additions and 2 deletions

View File

@@ -18,7 +18,12 @@ from vllm.model_executor.layers.attention.mla_attention import (
QueryLenSupport,
)
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backend import AttentionCGSupport, AttentionLayer, MultipleOf
from vllm.v1.attention.backend import (
AttentionCGSupport,
AttentionLayer,
CommonAttentionMetadata,
MultipleOf,
)
from vllm.v1.kv_cache_interface import AttentionSpec
@@ -71,8 +76,14 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
max_qo_len: int | None = None
@dataclass
class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
    """MLA attention metadata extended with AITER persistent-kernel buffers.

    The six optional tensors below are the work-scheduling and reduction
    buffers produced by AITER's ``get_mla_metadata_v1`` and consumed by the
    persistent MLA decode kernel.  They default to ``None`` and are filled
    in by the metadata builder; exact shapes/dtypes are dictated by
    ``get_mla_metadata_info_v1`` — see the builder (not reproduced here).
    """

    # NOTE: the stray `pass` left over from the previous empty body was
    # removed — it was dead code once real fields were added.
    work_meta_data: torch.Tensor | None = None
    work_indptr: torch.Tensor | None = None
    work_info_set: torch.Tensor | None = None
    reduce_indptr: torch.Tensor | None = None
    reduce_final_map: torch.Tensor | None = None
    reduce_partial_map: torch.Tensor | None = None
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
@@ -116,6 +127,55 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
max_num_pages, dtype=torch.int32, device=device
)
from aiter import dtypes, get_mla_metadata_info_v1
self._num_attention_heads = vllm_config.model_config.get_num_attention_heads(
vllm_config.parallel_config
)
q_dtype = self.decode_attn_out_dtype
kv_cache_dtype_str = getattr(vllm_config.cache_config, "cache_dtype", "auto")
if kv_cache_dtype_str in ("fp8", "fp8_e4m3", "fp8_e5m2"):
kv_cache_dtype_str = "fp8"
else:
kv_cache_dtype_str = "bf16"
kv_dtype = dtypes.d_dtypes.get(kv_cache_dtype_str, dtypes.bf16)
(
(work_meta_data_size, work_meta_data_type),
(work_indptr_size, work_indptr_type),
(work_info_set_size, work_info_set_type),
(reduce_indptr_size, reduce_indptr_type),
(reduce_final_map_size, reduce_final_map_type),
(reduce_partial_map_size, reduce_partial_map_type),
) = get_mla_metadata_info_v1(
max_num_reqs,
1,
self._num_attention_heads,
q_dtype,
kv_dtype,
is_sparse=False,
fast_mode=True,
)
self._mla_work_meta_data = torch.empty(
work_meta_data_size, dtype=work_meta_data_type, device=device
)
self._mla_work_indptr = torch.empty(
work_indptr_size, dtype=work_indptr_type, device=device
)
self._mla_work_info_set = torch.empty(
work_info_set_size, dtype=work_info_set_type, device=device
)
self._mla_reduce_indptr = torch.empty(
reduce_indptr_size, dtype=reduce_indptr_type, device=device
)
self._mla_reduce_final_map = torch.empty(
reduce_final_map_size, dtype=reduce_final_map_type, device=device
)
self._mla_reduce_partial_map = torch.empty(
reduce_partial_map_size,
dtype=reduce_partial_map_type,
device=device,
)
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
self.paged_kv_indptr = torch.zeros(
max_num_reqs + 1, dtype=torch.int32, device=device
@@ -184,6 +244,28 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
0, num_reqs + 1, step=1, dtype=torch.int32, device=device
)
from aiter import get_mla_metadata_v1
get_mla_metadata_v1(
qo_indptr,
paged_kv_indptr,
paged_kv_last_page_len,
self._num_attention_heads,
1,
True,
self._mla_work_meta_data,
self._mla_work_info_set,
self._mla_work_indptr,
self._mla_reduce_indptr,
self._mla_reduce_final_map,
self._mla_reduce_partial_map,
page_size=1,
kv_granularity=16,
max_seqlen_qo=max_qo_len,
uni_seqlen_qo=max_qo_len,
fast_mode=True,
)
attn_metadata = AiterMLADecodeMetadata(
block_table=block_table_tensor,
seq_lens=seq_lens_device,
@@ -198,6 +280,23 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
return attn_metadata
def build(
    self,
    common_prefix_len: int,
    common_attn_metadata: CommonAttentionMetadata,
    fast_build: bool = False,
) -> AiterMLAMetadata:
    """Build the common MLA metadata, then attach the AITER persistent-kernel
    scheduling buffers precomputed by this builder.

    Delegates to the base builder and copies the six ``_mla_*`` work/reduce
    tensors onto the returned metadata so the decode kernel can consume them.
    """
    metadata = super().build(common_prefix_len, common_attn_metadata, fast_build)
    # Propagate every precomputed scheduling tensor from the builder onto
    # the per-step metadata object (field name <-> `_mla_`-prefixed attr).
    for field_name in (
        "work_meta_data",
        "work_indptr",
        "work_info_set",
        "reduce_indptr",
        "reduce_final_map",
        "reduce_partial_map",
    ):
        setattr(metadata, field_name, getattr(self, f"_mla_{field_name}"))
    return metadata
@triton.jit
def _copy_page_indices_kernel(
@@ -338,6 +437,12 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
attn_metadata.decode.paged_kv_last_page_len,
q_scale=layer._q_scale,
kv_scale=layer._k_scale,
work_meta_data=attn_metadata.work_meta_data,
work_indptr=attn_metadata.work_indptr,
work_info_set=attn_metadata.work_info_set,
reduce_indptr=attn_metadata.reduce_indptr,
reduce_final_map=attn_metadata.reduce_final_map,
reduce_partial_map=attn_metadata.reduce_partial_map,
)
if self._needs_head_repeat: