[ROCm] Utilize persistent MLA kernel from AITER (#36574)
Signed-off-by: Sathish Sanjeevi <sathish.krishnan.p.s@gmail.com>
This commit is contained in:
@@ -386,6 +386,12 @@ def _rocm_aiter_mla_decode_fwd_impl(
|
|||||||
logit_cap: float = 0.0,
|
logit_cap: float = 0.0,
|
||||||
q_scale: torch.Tensor | None = None,
|
q_scale: torch.Tensor | None = None,
|
||||||
kv_scale: torch.Tensor | None = None,
|
kv_scale: torch.Tensor | None = None,
|
||||||
|
work_meta_data: torch.Tensor | None = None,
|
||||||
|
work_indptr: torch.Tensor | None = None,
|
||||||
|
work_info_set: torch.Tensor | None = None,
|
||||||
|
reduce_indptr: torch.Tensor | None = None,
|
||||||
|
reduce_final_map: torch.Tensor | None = None,
|
||||||
|
reduce_partial_map: torch.Tensor | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
from aiter.mla import mla_decode_fwd
|
from aiter.mla import mla_decode_fwd
|
||||||
|
|
||||||
@@ -399,6 +405,29 @@ def _rocm_aiter_mla_decode_fwd_impl(
|
|||||||
kwargs["q_scale"] = q_scale
|
kwargs["q_scale"] = q_scale
|
||||||
kwargs["kv_scale"] = kv_scale
|
kwargs["kv_scale"] = kv_scale
|
||||||
|
|
||||||
|
if work_meta_data is not None:
|
||||||
|
assert work_indptr is not None, (
|
||||||
|
"work_indptr must be provided with work_meta_data"
|
||||||
|
)
|
||||||
|
assert work_info_set is not None, (
|
||||||
|
"work_info_set must be provided with work_meta_data"
|
||||||
|
)
|
||||||
|
assert reduce_indptr is not None, (
|
||||||
|
"reduce_indptr must be provided with work_meta_data"
|
||||||
|
)
|
||||||
|
assert reduce_final_map is not None, (
|
||||||
|
"reduce_final_map must be provided with work_meta_data"
|
||||||
|
)
|
||||||
|
assert reduce_partial_map is not None, (
|
||||||
|
"reduce_partial_map must be provided with work_meta_data"
|
||||||
|
)
|
||||||
|
kwargs["work_meta_data"] = work_meta_data
|
||||||
|
kwargs["work_indptr"] = work_indptr
|
||||||
|
kwargs["work_info_set"] = work_info_set
|
||||||
|
kwargs["reduce_indptr"] = reduce_indptr
|
||||||
|
kwargs["reduce_final_map"] = reduce_final_map
|
||||||
|
kwargs["reduce_partial_map"] = reduce_partial_map
|
||||||
|
|
||||||
mla_decode_fwd(
|
mla_decode_fwd(
|
||||||
q,
|
q,
|
||||||
kv_buffer.view(-1, 1, 1, q.shape[-1]),
|
kv_buffer.view(-1, 1, 1, q.shape[-1]),
|
||||||
@@ -425,6 +454,12 @@ def _rocm_aiter_mla_decode_fwd_fake(
|
|||||||
logit_cap: float = 0.0,
|
logit_cap: float = 0.0,
|
||||||
q_scale: torch.Tensor | None = None,
|
q_scale: torch.Tensor | None = None,
|
||||||
kv_scale: torch.Tensor | None = None,
|
kv_scale: torch.Tensor | None = None,
|
||||||
|
work_meta_data: torch.Tensor | None = None,
|
||||||
|
work_indptr: torch.Tensor | None = None,
|
||||||
|
work_info_set: torch.Tensor | None = None,
|
||||||
|
reduce_indptr: torch.Tensor | None = None,
|
||||||
|
reduce_final_map: torch.Tensor | None = None,
|
||||||
|
reduce_partial_map: torch.Tensor | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -1659,6 +1694,12 @@ class rocm_aiter_ops:
|
|||||||
logit_cap: float = 0.0,
|
logit_cap: float = 0.0,
|
||||||
q_scale: torch.Tensor | None = None,
|
q_scale: torch.Tensor | None = None,
|
||||||
kv_scale: torch.Tensor | None = None,
|
kv_scale: torch.Tensor | None = None,
|
||||||
|
work_meta_data: torch.Tensor | None = None,
|
||||||
|
work_indptr: torch.Tensor | None = None,
|
||||||
|
work_info_set: torch.Tensor | None = None,
|
||||||
|
reduce_indptr: torch.Tensor | None = None,
|
||||||
|
reduce_final_map: torch.Tensor | None = None,
|
||||||
|
reduce_partial_map: torch.Tensor | None = None,
|
||||||
):
|
):
|
||||||
torch.ops.vllm.rocm_aiter_mla_decode_fwd(
|
torch.ops.vllm.rocm_aiter_mla_decode_fwd(
|
||||||
q,
|
q,
|
||||||
@@ -1673,6 +1714,12 @@ class rocm_aiter_ops:
|
|||||||
logit_cap=logit_cap,
|
logit_cap=logit_cap,
|
||||||
q_scale=q_scale,
|
q_scale=q_scale,
|
||||||
kv_scale=kv_scale,
|
kv_scale=kv_scale,
|
||||||
|
work_meta_data=work_meta_data,
|
||||||
|
work_indptr=work_indptr,
|
||||||
|
work_info_set=work_info_set,
|
||||||
|
reduce_indptr=reduce_indptr,
|
||||||
|
reduce_final_map=reduce_final_map,
|
||||||
|
reduce_partial_map=reduce_partial_map,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@@ -18,7 +18,12 @@ from vllm.model_executor.layers.attention.mla_attention import (
|
|||||||
QueryLenSupport,
|
QueryLenSupport,
|
||||||
)
|
)
|
||||||
from vllm.triton_utils import tl, triton
|
from vllm.triton_utils import tl, triton
|
||||||
from vllm.v1.attention.backend import AttentionCGSupport, AttentionLayer, MultipleOf
|
from vllm.v1.attention.backend import (
|
||||||
|
AttentionCGSupport,
|
||||||
|
AttentionLayer,
|
||||||
|
CommonAttentionMetadata,
|
||||||
|
MultipleOf,
|
||||||
|
)
|
||||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||||
|
|
||||||
|
|
||||||
@@ -71,8 +76,14 @@ class AiterMLADecodeMetadata(MLACommonDecodeMetadata):
|
|||||||
max_qo_len: int | None = None
|
max_qo_len: int | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
|
class AiterMLAMetadata(MLACommonMetadata[AiterMLADecodeMetadata]):
|
||||||
pass
|
work_meta_data: torch.Tensor | None = None
|
||||||
|
work_indptr: torch.Tensor | None = None
|
||||||
|
work_info_set: torch.Tensor | None = None
|
||||||
|
reduce_indptr: torch.Tensor | None = None
|
||||||
|
reduce_final_map: torch.Tensor | None = None
|
||||||
|
reduce_partial_map: torch.Tensor | None = None
|
||||||
|
|
||||||
|
|
||||||
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||||
@@ -116,6 +127,55 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
|||||||
max_num_pages, dtype=torch.int32, device=device
|
max_num_pages, dtype=torch.int32, device=device
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from aiter import dtypes, get_mla_metadata_info_v1
|
||||||
|
|
||||||
|
self._num_attention_heads = vllm_config.model_config.get_num_attention_heads(
|
||||||
|
vllm_config.parallel_config
|
||||||
|
)
|
||||||
|
q_dtype = self.decode_attn_out_dtype
|
||||||
|
kv_cache_dtype_str = getattr(vllm_config.cache_config, "cache_dtype", "auto")
|
||||||
|
if kv_cache_dtype_str in ("fp8", "fp8_e4m3", "fp8_e5m2"):
|
||||||
|
kv_cache_dtype_str = "fp8"
|
||||||
|
else:
|
||||||
|
kv_cache_dtype_str = "bf16"
|
||||||
|
kv_dtype = dtypes.d_dtypes.get(kv_cache_dtype_str, dtypes.bf16)
|
||||||
|
(
|
||||||
|
(work_meta_data_size, work_meta_data_type),
|
||||||
|
(work_indptr_size, work_indptr_type),
|
||||||
|
(work_info_set_size, work_info_set_type),
|
||||||
|
(reduce_indptr_size, reduce_indptr_type),
|
||||||
|
(reduce_final_map_size, reduce_final_map_type),
|
||||||
|
(reduce_partial_map_size, reduce_partial_map_type),
|
||||||
|
) = get_mla_metadata_info_v1(
|
||||||
|
max_num_reqs,
|
||||||
|
1,
|
||||||
|
self._num_attention_heads,
|
||||||
|
q_dtype,
|
||||||
|
kv_dtype,
|
||||||
|
is_sparse=False,
|
||||||
|
fast_mode=True,
|
||||||
|
)
|
||||||
|
self._mla_work_meta_data = torch.empty(
|
||||||
|
work_meta_data_size, dtype=work_meta_data_type, device=device
|
||||||
|
)
|
||||||
|
self._mla_work_indptr = torch.empty(
|
||||||
|
work_indptr_size, dtype=work_indptr_type, device=device
|
||||||
|
)
|
||||||
|
self._mla_work_info_set = torch.empty(
|
||||||
|
work_info_set_size, dtype=work_info_set_type, device=device
|
||||||
|
)
|
||||||
|
self._mla_reduce_indptr = torch.empty(
|
||||||
|
reduce_indptr_size, dtype=reduce_indptr_type, device=device
|
||||||
|
)
|
||||||
|
self._mla_reduce_final_map = torch.empty(
|
||||||
|
reduce_final_map_size, dtype=reduce_final_map_type, device=device
|
||||||
|
)
|
||||||
|
self._mla_reduce_partial_map = torch.empty(
|
||||||
|
reduce_partial_map_size,
|
||||||
|
dtype=reduce_partial_map_type,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
if self.compilation_config.cudagraph_mode.has_full_cudagraphs():
|
||||||
self.paged_kv_indptr = torch.zeros(
|
self.paged_kv_indptr = torch.zeros(
|
||||||
max_num_reqs + 1, dtype=torch.int32, device=device
|
max_num_reqs + 1, dtype=torch.int32, device=device
|
||||||
@@ -184,6 +244,28 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
|||||||
0, num_reqs + 1, step=1, dtype=torch.int32, device=device
|
0, num_reqs + 1, step=1, dtype=torch.int32, device=device
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from aiter import get_mla_metadata_v1
|
||||||
|
|
||||||
|
get_mla_metadata_v1(
|
||||||
|
qo_indptr,
|
||||||
|
paged_kv_indptr,
|
||||||
|
paged_kv_last_page_len,
|
||||||
|
self._num_attention_heads,
|
||||||
|
1,
|
||||||
|
True,
|
||||||
|
self._mla_work_meta_data,
|
||||||
|
self._mla_work_info_set,
|
||||||
|
self._mla_work_indptr,
|
||||||
|
self._mla_reduce_indptr,
|
||||||
|
self._mla_reduce_final_map,
|
||||||
|
self._mla_reduce_partial_map,
|
||||||
|
page_size=1,
|
||||||
|
kv_granularity=16,
|
||||||
|
max_seqlen_qo=max_qo_len,
|
||||||
|
uni_seqlen_qo=max_qo_len,
|
||||||
|
fast_mode=True,
|
||||||
|
)
|
||||||
|
|
||||||
attn_metadata = AiterMLADecodeMetadata(
|
attn_metadata = AiterMLADecodeMetadata(
|
||||||
block_table=block_table_tensor,
|
block_table=block_table_tensor,
|
||||||
seq_lens=seq_lens_device,
|
seq_lens=seq_lens_device,
|
||||||
@@ -198,6 +280,23 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
|||||||
|
|
||||||
return attn_metadata
|
return attn_metadata
|
||||||
|
|
||||||
|
def build(
    self,
    common_prefix_len: int,
    common_attn_metadata: CommonAttentionMetadata,
    fast_build: bool = False,
) -> AiterMLAMetadata:
    """Build the attention metadata and attach the persistent-kernel buffers.

    Delegates to the parent builder with the same three arguments, then
    stores the six builder-owned ``_mla_*`` tensors on the result so the
    decode forward call can pass them through as keyword arguments.

    Args:
        common_prefix_len: Forwarded unchanged to the parent ``build``.
        common_attn_metadata: Forwarded unchanged to the parent ``build``.
        fast_build: Forwarded unchanged to the parent ``build``.

    Returns:
        The metadata object produced by the parent builder, with
        ``work_meta_data``, ``work_indptr``, ``work_info_set``,
        ``reduce_indptr``, ``reduce_final_map`` and ``reduce_partial_map``
        set.
    """
    metadata = super().build(common_prefix_len, common_attn_metadata, fast_build)

    # The tensors are attached by reference (no copy): each metadata field
    # aliases the builder attribute of the same name prefixed with "_mla_",
    # i.e. the buffers handed to get_mla_metadata_v1 during the decode build.
    persistent_fields = (
        "work_meta_data",
        "work_indptr",
        "work_info_set",
        "reduce_indptr",
        "reduce_final_map",
        "reduce_partial_map",
    )
    for field_name in persistent_fields:
        setattr(metadata, field_name, getattr(self, "_mla_" + field_name))

    return metadata
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
@triton.jit
|
||||||
def _copy_page_indices_kernel(
|
def _copy_page_indices_kernel(
|
||||||
@@ -338,6 +437,12 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
|
|||||||
attn_metadata.decode.paged_kv_last_page_len,
|
attn_metadata.decode.paged_kv_last_page_len,
|
||||||
q_scale=layer._q_scale,
|
q_scale=layer._q_scale,
|
||||||
kv_scale=layer._k_scale,
|
kv_scale=layer._k_scale,
|
||||||
|
work_meta_data=attn_metadata.work_meta_data,
|
||||||
|
work_indptr=attn_metadata.work_indptr,
|
||||||
|
work_info_set=attn_metadata.work_info_set,
|
||||||
|
reduce_indptr=attn_metadata.reduce_indptr,
|
||||||
|
reduce_final_map=attn_metadata.reduce_final_map,
|
||||||
|
reduce_partial_map=attn_metadata.reduce_partial_map,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self._needs_head_repeat:
|
if self._needs_head_repeat:
|
||||||
|
|||||||
Reference in New Issue
Block a user