[ROCm] Utilize persistent MLA kernel from AITER (#36574)
Signed-off-by: Sathish Sanjeevi <sathish.krishnan.p.s@gmail.com>
This commit is contained in:
@@ -18,7 +18,12 @@ from vllm.model_executor.layers.attention.mla_attention import (
|
||||
QueryLenSupport,
|
||||
)
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.attention.backend import AttentionCGSupport, AttentionLayer, MultipleOf
|
||||
from vllm.v1.attention.backend import (
|
||||
AttentionCGSupport,
|
||||
AttentionLayer,
|
||||
CommonAttentionMetadata,
|
||||
MultipleOf,
|
||||
)
|
||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
|
||||
|
||||
@@ -71,8 +76,14 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
|
||||
max_qo_len: int | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
|
||||
pass
|
||||
work_meta_data: torch.Tensor | None = None
|
||||
work_indptr: torch.Tensor | None = None
|
||||
work_info_set: torch.Tensor | None = None
|
||||
reduce_indptr: torch.Tensor | None = None
|
||||
reduce_final_map: torch.Tensor | None = None
|
||||
reduce_partial_map: torch.Tensor | None = None
|
||||
|
||||
|
||||
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||
@@ -116,6 +127,55 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||
max_num_pages, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
from aiter import dtypes, get_mla_metadata_info_v1
|
||||
|
||||
self._num_attention_heads = vllm_config.model_config.get_num_attention_heads(
|
||||
vllm_config.parallel_config
|
||||
)
|
||||
q_dtype = self.decode_attn_out_dtype
|
||||
kv_cache_dtype_str = getattr(vllm_config.cache_config, "cache_dtype", "auto")
|
||||
if kv_cache_dtype_str in ("fp8", "fp8_e4m3", "fp8_e5m2"):
|
||||
kv_cache_dtype_str = "fp8"
|
||||
else:
|
||||
kv_cache_dtype_str = "bf16"
|
||||
kv_dtype = dtypes.d_dtypes.get(kv_cache_dtype_str, dtypes.bf16)
|
||||
(
|
||||
(work_meta_data_size, work_meta_data_type),
|
||||
(work_indptr_size, work_indptr_type),
|
||||
(work_info_set_size, work_info_set_type),
|
||||
(reduce_indptr_size, reduce_indptr_type),
|
||||
(reduce_final_map_size, reduce_final_map_type),
|
||||
(reduce_partial_map_size, reduce_partial_map_type),
|
||||
) = get_mla_metadata_info_v1(
|
||||
max_num_reqs,
|
||||
1,
|
||||
self._num_attention_heads,
|
||||
q_dtype,
|
||||
kv_dtype,
|
||||
is_sparse=False,
|
||||
fast_mode=True,
|
||||
)
|
||||
self._mla_work_meta_data = torch.empty(
|
||||
work_meta_data_size, dtype=work_meta_data_type, device=device
|
||||
)
|
||||
self._mla_work_indptr = torch.empty(
|
||||
work_indptr_size, dtype=work_indptr_type, device=device
|
||||
)
|
||||
self._mla_work_info_set = torch.empty(
|
||||
work_info_set_size, dtype=work_info_set_type, device=device
|
||||
)
|
||||
self._mla_reduce_indptr = torch.empty(
|
||||
reduce_indptr_size, dtype=reduce_indptr_type, device=device
|
||||
)
|
||||
self._mla_reduce_final_map = torch.empty(
|
||||
reduce_final_map_size, dtype=reduce_final_map_type, device=device
|
||||
)
|
||||
self._mla_reduce_partial_map = torch.empty(
|
||||
reduce_partial_map_size,
|
||||
dtype=reduce_partial_map_type,
|
||||
device=device,
|
||||
)
|
||||
|
||||
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||||
self.paged_kv_indptr = torch.zeros(
|
||||
max_num_reqs + 1, dtype=torch.int32, device=device
|
||||
@@ -184,6 +244,28 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||
0, num_reqs + 1, step=1, dtype=torch.int32, device=device
|
||||
)
|
||||
|
||||
from aiter import get_mla_metadata_v1
|
||||
|
||||
get_mla_metadata_v1(
|
||||
qo_indptr,
|
||||
paged_kv_indptr,
|
||||
paged_kv_last_page_len,
|
||||
self._num_attention_heads,
|
||||
1,
|
||||
True,
|
||||
self._mla_work_meta_data,
|
||||
self._mla_work_info_set,
|
||||
self._mla_work_indptr,
|
||||
self._mla_reduce_indptr,
|
||||
self._mla_reduce_final_map,
|
||||
self._mla_reduce_partial_map,
|
||||
page_size=1,
|
||||
kv_granularity=16,
|
||||
max_seqlen_qo=max_qo_len,
|
||||
uni_seqlen_qo=max_qo_len,
|
||||
fast_mode=True,
|
||||
)
|
||||
|
||||
attn_metadata = AiterMLADecodeMetadata(
|
||||
block_table=block_table_tensor,
|
||||
seq_lens=seq_lens_device,
|
||||
@@ -198,6 +280,23 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||
|
||||
return attn_metadata
|
||||
|
||||
def build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
) -> AiterMLAMetadata:
|
||||
attn_metadata = super().build(
|
||||
common_prefix_len, common_attn_metadata, fast_build
|
||||
)
|
||||
attn_metadata.work_meta_data = self._mla_work_meta_data
|
||||
attn_metadata.work_indptr = self._mla_work_indptr
|
||||
attn_metadata.work_info_set = self._mla_work_info_set
|
||||
attn_metadata.reduce_indptr = self._mla_reduce_indptr
|
||||
attn_metadata.reduce_final_map = self._mla_reduce_final_map
|
||||
attn_metadata.reduce_partial_map = self._mla_reduce_partial_map
|
||||
return attn_metadata
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _copy_page_indices_kernel(
|
||||
@@ -338,6 +437,12 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
|
||||
attn_metadata.decode.paged_kv_last_page_len,
|
||||
q_scale=layer._q_scale,
|
||||
kv_scale=layer._k_scale,
|
||||
work_meta_data=attn_metadata.work_meta_data,
|
||||
work_indptr=attn_metadata.work_indptr,
|
||||
work_info_set=attn_metadata.work_info_set,
|
||||
reduce_indptr=attn_metadata.reduce_indptr,
|
||||
reduce_final_map=attn_metadata.reduce_final_map,
|
||||
reduce_partial_map=attn_metadata.reduce_partial_map,
|
||||
)
|
||||
|
||||
if self._needs_head_repeat:
|
||||
|
||||
Reference in New Issue
Block a user