[ROCm] Utilize persistent MLA kernel from AITER (#36574)
Signed-off-by: Sathish Sanjeevi <sathish.krishnan.p.s@gmail.com>
This commit is contained in:
@@ -386,6 +386,12 @@ def _rocm_aiter_mla_decode_fwd_impl(
|
||||
logit_cap: float = 0.0,
|
||||
q_scale: torch.Tensor | None = None,
|
||||
kv_scale: torch.Tensor | None = None,
|
||||
work_meta_data: torch.Tensor | None = None,
|
||||
work_indptr: torch.Tensor | None = None,
|
||||
work_info_set: torch.Tensor | None = None,
|
||||
reduce_indptr: torch.Tensor | None = None,
|
||||
reduce_final_map: torch.Tensor | None = None,
|
||||
reduce_partial_map: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
from aiter.mla import mla_decode_fwd
|
||||
|
||||
@@ -399,6 +405,29 @@ def _rocm_aiter_mla_decode_fwd_impl(
|
||||
kwargs["q_scale"] = q_scale
|
||||
kwargs["kv_scale"] = kv_scale
|
||||
|
||||
if work_meta_data is not None:
|
||||
assert work_indptr is not None, (
|
||||
"work_indptr must be provided with work_meta_data"
|
||||
)
|
||||
assert work_info_set is not None, (
|
||||
"work_info_set must be provided with work_meta_data"
|
||||
)
|
||||
assert reduce_indptr is not None, (
|
||||
"reduce_indptr must be provided with work_meta_data"
|
||||
)
|
||||
assert reduce_final_map is not None, (
|
||||
"reduce_final_map must be provided with work_meta_data"
|
||||
)
|
||||
assert reduce_partial_map is not None, (
|
||||
"reduce_partial_map must be provided with work_meta_data"
|
||||
)
|
||||
kwargs["work_meta_data"] = work_meta_data
|
||||
kwargs["work_indptr"] = work_indptr
|
||||
kwargs["work_info_set"] = work_info_set
|
||||
kwargs["reduce_indptr"] = reduce_indptr
|
||||
kwargs["reduce_final_map"] = reduce_final_map
|
||||
kwargs["reduce_partial_map"] = reduce_partial_map
|
||||
|
||||
mla_decode_fwd(
|
||||
q,
|
||||
kv_buffer.view(-1, 1, 1, q.shape[-1]),
|
||||
@@ -425,6 +454,12 @@ def _rocm_aiter_mla_decode_fwd_fake(
|
||||
logit_cap: float = 0.0,
|
||||
q_scale: torch.Tensor | None = None,
|
||||
kv_scale: torch.Tensor | None = None,
|
||||
work_meta_data: torch.Tensor | None = None,
|
||||
work_indptr: torch.Tensor | None = None,
|
||||
work_info_set: torch.Tensor | None = None,
|
||||
reduce_indptr: torch.Tensor | None = None,
|
||||
reduce_final_map: torch.Tensor | None = None,
|
||||
reduce_partial_map: torch.Tensor | None = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@@ -1659,6 +1694,12 @@ class rocm_aiter_ops:
|
||||
logit_cap: float = 0.0,
|
||||
q_scale: torch.Tensor | None = None,
|
||||
kv_scale: torch.Tensor | None = None,
|
||||
work_meta_data: torch.Tensor | None = None,
|
||||
work_indptr: torch.Tensor | None = None,
|
||||
work_info_set: torch.Tensor | None = None,
|
||||
reduce_indptr: torch.Tensor | None = None,
|
||||
reduce_final_map: torch.Tensor | None = None,
|
||||
reduce_partial_map: torch.Tensor | None = None,
|
||||
):
|
||||
torch.ops.vllm.rocm_aiter_mla_decode_fwd(
|
||||
q,
|
||||
@@ -1673,6 +1714,12 @@ class rocm_aiter_ops:
|
||||
logit_cap=logit_cap,
|
||||
q_scale=q_scale,
|
||||
kv_scale=kv_scale,
|
||||
work_meta_data=work_meta_data,
|
||||
work_indptr=work_indptr,
|
||||
work_info_set=work_info_set,
|
||||
reduce_indptr=reduce_indptr,
|
||||
reduce_final_map=reduce_final_map,
|
||||
reduce_partial_map=reduce_partial_map,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -18,7 +18,12 @@ from vllm.model_executor.layers.attention.mla_attention import (
|
||||
QueryLenSupport,
|
||||
)
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.attention.backend import AttentionCGSupport, AttentionLayer, MultipleOf
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionCGSupport,
|
||||
AttentionLayer,
|
||||
CommonAttentionMetadata,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
|
||||
@@ -71,8 +76,14 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
|
||||
max_qo_len: int | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
|
||||
pass
|
||||
work_meta_data: torch.Tensor | None = None
|
||||
work_indptr: torch.Tensor | None = None
|
||||
work_info_set: torch.Tensor | None = None
|
||||
reduce_indptr: torch.Tensor | None = None
|
||||
reduce_final_map: torch.Tensor | None = None
|
||||
reduce_partial_map: torch.Tensor | None = None
|
||||
|
||||
|
||||
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||
@@ -116,6 +127,55 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||
max_num_pages, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
from aiter import dtypes, get_mla_metadata_info_v1
|
||||
|
||||
self._num_attention_heads = vllm_config.model_config.get_num_attention_heads(
|
||||
vllm_config.parallel_config
|
||||
)
|
||||
q_dtype = self.decode_attn_out_dtype
|
||||
kv_cache_dtype_str = getattr(vllm_config.cache_config, "cache_dtype", "auto")
|
||||
if kv_cache_dtype_str in ("fp8", "fp8_e4m3", "fp8_e5m2"):
|
||||
kv_cache_dtype_str = "fp8"
|
||||
else:
|
||||
kv_cache_dtype_str = "bf16"
|
||||
kv_dtype = dtypes.d_dtypes.get(kv_cache_dtype_str, dtypes.bf16)
|
||||
(
|
||||
(work_meta_data_size, work_meta_data_type),
|
||||
(work_indptr_size, work_indptr_type),
|
||||
(work_info_set_size, work_info_set_type),
|
||||
(reduce_indptr_size, reduce_indptr_type),
|
||||
(reduce_final_map_size, reduce_final_map_type),
|
||||
(reduce_partial_map_size, reduce_partial_map_type),
|
||||
) = get_mla_metadata_info_v1(
|
||||
max_num_reqs,
|
||||
1,
|
||||
self._num_attention_heads,
|
||||
q_dtype,
|
||||
kv_dtype,
|
||||
is_sparse=False,
|
||||
fast_mode=True,
|
||||
)
|
||||
self._mla_work_meta_data = torch.empty(
|
||||
work_meta_data_size, dtype=work_meta_data_type, device=device
|
||||
)
|
||||
self._mla_work_indptr = torch.empty(
|
||||
work_indptr_size, dtype=work_indptr_type, device=device
|
||||
)
|
||||
self._mla_work_info_set = torch.empty(
|
||||
work_info_set_size, dtype=work_info_set_type, device=device
|
||||
)
|
||||
self._mla_reduce_indptr = torch.empty(
|
||||
reduce_indptr_size, dtype=reduce_indptr_type, device=device
|
||||
)
|
||||
self._mla_reduce_final_map = torch.empty(
|
||||
reduce_final_map_size, dtype=reduce_final_map_type, device=device
|
||||
)
|
||||
self._mla_reduce_partial_map = torch.empty(
|
||||
reduce_partial_map_size,
|
||||
dtype=reduce_partial_map_type,
|
||||
device=device,
|
||||
)
|
||||
|
||||
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||||
self.paged_kv_indptr = torch.zeros(
|
||||
max_num_reqs + 1, dtype=torch.int32, device=device
|
||||
@@ -184,6 +244,28 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||
0, num_reqs + 1, step=1, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
from aiter import get_mla_metadata_v1
|
||||
|
||||
get_mla_metadata_v1(
|
||||
qo_indptr,
|
||||
paged_kv_indptr,
|
||||
paged_kv_last_page_len,
|
||||
self._num_attention_heads,
|
||||
1,
|
||||
True,
|
||||
self._mla_work_meta_data,
|
||||
self._mla_work_info_set,
|
||||
self._mla_work_indptr,
|
||||
self._mla_reduce_indptr,
|
||||
self._mla_reduce_final_map,
|
||||
self._mla_reduce_partial_map,
|
||||
page_size=1,
|
||||
kv_granularity=16,
|
||||
max_seqlen_qo=max_qo_len,
|
||||
uni_seqlen_qo=max_qo_len,
|
||||
fast_mode=True,
|
||||
)
|
||||
|
||||
attn_metadata = AiterMLADecodeMetadata(
|
||||
block_table=block_table_tensor,
|
||||
seq_lens=seq_lens_device,
|
||||
@@ -198,6 +280,23 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||
|
||||
return attn_metadata
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
) -> AiterMLAMetadata:
|
||||
attn_metadata = super().build(
|
||||
common_prefix_len, common_attn_metadata, fast_build
|
||||
)
|
||||
attn_metadata.work_meta_data = self._mla_work_meta_data
|
||||
attn_metadata.work_indptr = self._mla_work_indptr
|
||||
attn_metadata.work_info_set = self._mla_work_info_set
|
||||
attn_metadata.reduce_indptr = self._mla_reduce_indptr
|
||||
attn_metadata.reduce_final_map = self._mla_reduce_final_map
|
||||
attn_metadata.reduce_partial_map = self._mla_reduce_partial_map
|
||||
return attn_metadata
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _copy_page_indices_kernel(
|
||||
@@ -338,6 +437,12 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
|
||||
attn_metadata.decode.paged_kv_last_page_len,
|
||||
q_scale=layer._q_scale,
|
||||
kv_scale=layer._k_scale,
|
||||
work_meta_data=attn_metadata.work_meta_data,
|
||||
work_indptr=attn_metadata.work_indptr,
|
||||
work_info_set=attn_metadata.work_info_set,
|
||||
reduce_indptr=attn_metadata.reduce_indptr,
|
||||
reduce_final_map=attn_metadata.reduce_final_map,
|
||||
reduce_partial_map=attn_metadata.reduce_partial_map,
|
||||
)
|
||||
|
||||
if self._needs_head_repeat:
|
||||
|
||||
Reference in New Issue
Block a user