[ROCm] Utilize persistent MLA kernel from AITER (#36574)

Signed-off-by: Sathish Sanjeevi <sathish.krishnan.p.s@gmail.com>
This commit is contained in:
Sathish Sanjeevi
2026-03-25 12:00:42 -07:00
committed by GitHub
parent 7d6917bef5
commit 978fc18bf0
2 changed files with 154 additions and 2 deletions

View File

@@ -386,6 +386,12 @@ def _rocm_aiter_mla_decode_fwd_impl(
logit_cap: float = 0.0, logit_cap: float = 0.0,
q_scale: torch.Tensor | None = None, q_scale: torch.Tensor | None = None,
kv_scale: torch.Tensor | None = None, kv_scale: torch.Tensor | None = None,
work_meta_data: torch.Tensor | None = None,
work_indptr: torch.Tensor | None = None,
work_info_set: torch.Tensor | None = None,
reduce_indptr: torch.Tensor | None = None,
reduce_final_map: torch.Tensor | None = None,
reduce_partial_map: torch.Tensor | None = None,
) -> None: ) -> None:
from aiter.mla import mla_decode_fwd from aiter.mla import mla_decode_fwd
@@ -399,6 +405,29 @@ def _rocm_aiter_mla_decode_fwd_impl(
kwargs["q_scale"] = q_scale kwargs["q_scale"] = q_scale
kwargs["kv_scale"] = kv_scale kwargs["kv_scale"] = kv_scale
if work_meta_data is not None:
assert work_indptr is not None, (
"work_indptr must be provided with work_meta_data"
)
assert work_info_set is not None, (
"work_info_set must be provided with work_meta_data"
)
assert reduce_indptr is not None, (
"reduce_indptr must be provided with work_meta_data"
)
assert reduce_final_map is not None, (
"reduce_final_map must be provided with work_meta_data"
)
assert reduce_partial_map is not None, (
"reduce_partial_map must be provided with work_meta_data"
)
kwargs["work_meta_data"] = work_meta_data
kwargs["work_indptr"] = work_indptr
kwargs["work_info_set"] = work_info_set
kwargs["reduce_indptr"] = reduce_indptr
kwargs["reduce_final_map"] = reduce_final_map
kwargs["reduce_partial_map"] = reduce_partial_map
mla_decode_fwd( mla_decode_fwd(
q, q,
kv_buffer.view(-1, 1, 1, q.shape[-1]), kv_buffer.view(-1, 1, 1, q.shape[-1]),
@@ -425,6 +454,12 @@ def _rocm_aiter_mla_decode_fwd_fake(
logit_cap: float = 0.0, logit_cap: float = 0.0,
q_scale: torch.Tensor | None = None, q_scale: torch.Tensor | None = None,
kv_scale: torch.Tensor | None = None, kv_scale: torch.Tensor | None = None,
work_meta_data: torch.Tensor | None = None,
work_indptr: torch.Tensor | None = None,
work_info_set: torch.Tensor | None = None,
reduce_indptr: torch.Tensor | None = None,
reduce_final_map: torch.Tensor | None = None,
reduce_partial_map: torch.Tensor | None = None,
) -> None: ) -> None:
pass pass
@@ -1659,6 +1694,12 @@ class rocm_aiter_ops:
logit_cap: float = 0.0, logit_cap: float = 0.0,
q_scale: torch.Tensor | None = None, q_scale: torch.Tensor | None = None,
kv_scale: torch.Tensor | None = None, kv_scale: torch.Tensor | None = None,
work_meta_data: torch.Tensor | None = None,
work_indptr: torch.Tensor | None = None,
work_info_set: torch.Tensor | None = None,
reduce_indptr: torch.Tensor | None = None,
reduce_final_map: torch.Tensor | None = None,
reduce_partial_map: torch.Tensor | None = None,
): ):
torch.ops.vllm.rocm_aiter_mla_decode_fwd( torch.ops.vllm.rocm_aiter_mla_decode_fwd(
q, q,
@@ -1673,6 +1714,12 @@ class rocm_aiter_ops:
logit_cap=logit_cap, logit_cap=logit_cap,
q_scale=q_scale, q_scale=q_scale,
kv_scale=kv_scale, kv_scale=kv_scale,
work_meta_data=work_meta_data,
work_indptr=work_indptr,
work_info_set=work_info_set,
reduce_indptr=reduce_indptr,
reduce_final_map=reduce_final_map,
reduce_partial_map=reduce_partial_map,
) )
@staticmethod @staticmethod

View File

@@ -18,7 +18,12 @@ from vllm.model_executor.layers.attention.mla_attention import (
QueryLenSupport, QueryLenSupport,
) )
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.v1.attention.backend import AttentionCGSupport, AttentionLayer, MultipleOf from vllm.v1.attention.backend import (
AttentionCGSupport,
AttentionLayer,
CommonAttentionMetadata,
MultipleOf,
)
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
@@ -71,8 +76,14 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
max_qo_len: int | None = None max_qo_len: int | None = None
@dataclass
class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]): class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
pass work_meta_data: torch.Tensor | None = None
work_indptr: torch.Tensor | None = None
work_info_set: torch.Tensor | None = None
reduce_indptr: torch.Tensor | None = None
reduce_final_map: torch.Tensor | None = None
reduce_partial_map: torch.Tensor | None = None
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]): class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
@@ -116,6 +127,55 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
max_num_pages, dtype=torch.int32, device=device max_num_pages, dtype=torch.int32, device=device
) )
from aiter import dtypes, get_mla_metadata_info_v1
self._num_attention_heads = vllm_config.model_config.get_num_attention_heads(
vllm_config.parallel_config
)
q_dtype = self.decode_attn_out_dtype
kv_cache_dtype_str = getattr(vllm_config.cache_config, "cache_dtype", "auto")
if kv_cache_dtype_str in ("fp8", "fp8_e4m3", "fp8_e5m2"):
kv_cache_dtype_str = "fp8"
else:
kv_cache_dtype_str = "bf16"
kv_dtype = dtypes.d_dtypes.get(kv_cache_dtype_str, dtypes.bf16)
(
(work_meta_data_size, work_meta_data_type),
(work_indptr_size, work_indptr_type),
(work_info_set_size, work_info_set_type),
(reduce_indptr_size, reduce_indptr_type),
(reduce_final_map_size, reduce_final_map_type),
(reduce_partial_map_size, reduce_partial_map_type),
) = get_mla_metadata_info_v1(
max_num_reqs,
1,
self._num_attention_heads,
q_dtype,
kv_dtype,
is_sparse=False,
fast_mode=True,
)
self._mla_work_meta_data = torch.empty(
work_meta_data_size, dtype=work_meta_data_type, device=device
)
self._mla_work_indptr = torch.empty(
work_indptr_size, dtype=work_indptr_type, device=device
)
self._mla_work_info_set = torch.empty(
work_info_set_size, dtype=work_info_set_type, device=device
)
self._mla_reduce_indptr = torch.empty(
reduce_indptr_size, dtype=reduce_indptr_type, device=device
)
self._mla_reduce_final_map = torch.empty(
reduce_final_map_size, dtype=reduce_final_map_type, device=device
)
self._mla_reduce_partial_map = torch.empty(
reduce_partial_map_size,
dtype=reduce_partial_map_type,
device=device,
)
if self.compilation_config.cudagraph_mode.has_full_cudagraphs(): if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
self.paged_kv_indptr = torch.zeros( self.paged_kv_indptr = torch.zeros(
max_num_reqs + 1, dtype=torch.int32, device=device max_num_reqs + 1, dtype=torch.int32, device=device
@@ -184,6 +244,28 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
0, num_reqs + 1, step=1, dtype=torch.int32, device=device 0, num_reqs + 1, step=1, dtype=torch.int32, device=device
) )
from aiter import get_mla_metadata_v1
get_mla_metadata_v1(
qo_indptr,
paged_kv_indptr,
paged_kv_last_page_len,
self._num_attention_heads,
1,
True,
self._mla_work_meta_data,
self._mla_work_info_set,
self._mla_work_indptr,
self._mla_reduce_indptr,
self._mla_reduce_final_map,
self._mla_reduce_partial_map,
page_size=1,
kv_granularity=16,
max_seqlen_qo=max_qo_len,
uni_seqlen_qo=max_qo_len,
fast_mode=True,
)
attn_metadata = AiterMLADecodeMetadata( attn_metadata = AiterMLADecodeMetadata(
block_table=block_table_tensor, block_table=block_table_tensor,
seq_lens=seq_lens_device, seq_lens=seq_lens_device,
@@ -198,6 +280,23 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
return attn_metadata return attn_metadata
def build(
    self,
    common_prefix_len: int,
    common_attn_metadata: CommonAttentionMetadata,
    fast_build: bool = False,
) -> AiterMLAMetadata:
    """Build attention metadata, then attach the AITER persistent-MLA buffers.

    Delegates the base metadata construction to the parent builder and
    copies the six pre-allocated scheduling/reduction buffers (created in
    ``__init__``) onto the returned metadata object so the decode kernel
    can consume them.
    """
    metadata = super().build(common_prefix_len, common_attn_metadata, fast_build)
    # Each metadata field mirrors a builder attribute prefixed with "_mla_".
    for buffer_name in (
        "work_meta_data",
        "work_indptr",
        "work_info_set",
        "reduce_indptr",
        "reduce_final_map",
        "reduce_partial_map",
    ):
        setattr(metadata, buffer_name, getattr(self, f"_mla_{buffer_name}"))
    return metadata
@triton.jit @triton.jit
def _copy_page_indices_kernel( def _copy_page_indices_kernel(
@@ -338,6 +437,12 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
attn_metadata.decode.paged_kv_last_page_len, attn_metadata.decode.paged_kv_last_page_len,
q_scale=layer._q_scale, q_scale=layer._q_scale,
kv_scale=layer._k_scale, kv_scale=layer._k_scale,
work_meta_data=attn_metadata.work_meta_data,
work_indptr=attn_metadata.work_indptr,
work_info_set=attn_metadata.work_info_set,
reduce_indptr=attn_metadata.reduce_indptr,
reduce_final_map=attn_metadata.reduce_final_map,
reduce_partial_map=attn_metadata.reduce_partial_map,
) )
if self._needs_head_repeat: if self._needs_head_repeat: