[ROCm] Add AMD GPU support for DeepSeek v3.2 and SparseMLA (#26670)

Signed-off-by: ganyi <ygan@amd.com>
This commit is contained in:
Pleaplusone
2025-11-20 18:54:01 +08:00
committed by GitHub
parent 6eb745d9bd
commit 06c20c9904
9 changed files with 583 additions and 15 deletions

View File

@@ -594,6 +594,7 @@ def sparse_attn_indexer(
) -> torch.Tensor:
# careful! this will be None in dummy run
attn_metadata = get_forward_context().attn_metadata
fp8_dtype = current_platform.fp8_dtype()
# assert isinstance(attn_metadata, dict)
if not isinstance(attn_metadata, dict):
return sparse_attn_indexer_fake(
@@ -633,7 +634,7 @@ def sparse_attn_indexer(
k_fp8 = torch.empty(
[chunk.total_seq_lens, head_dim],
device=k.device,
dtype=torch.float8_e4m3fn,
dtype=fp8_dtype,
)
k_scale = torch.empty(
[chunk.total_seq_lens, 4],
@@ -647,7 +648,12 @@ def sparse_attn_indexer(
chunk.block_table,
chunk.cu_seq_lens,
)
logits = fp8_mqa_logits(
fp8_mqa_logits_func = fp8_mqa_logits
if current_platform.is_rocm():
from vllm.attention.ops.rocm_aiter_mla_sparse import rocm_fp8_mqa_logits
fp8_mqa_logits_func = rocm_fp8_mqa_logits
logits = fp8_mqa_logits_func(
q_fp8[chunk.token_start : chunk.token_end],
(k_fp8, k_scale.view(torch.float32)),
weights[chunk.token_start : chunk.token_end],
@@ -692,7 +698,14 @@ def sparse_attn_indexer(
next_n = padded_q_fp8_decode_tokens.shape[1]
assert batch_size == decode_metadata.seq_lens.shape[0]
num_padded_tokens = batch_size * next_n
logits = fp8_paged_mqa_logits(
fp8_paged_mqa_logits_func = fp8_paged_mqa_logits
if current_platform.is_rocm():
from vllm.attention.ops.rocm_aiter_mla_sparse import (
rocm_fp8_paged_mqa_logits,
)
fp8_paged_mqa_logits_func = rocm_fp8_paged_mqa_logits
logits = fp8_paged_mqa_logits_func(
padded_q_fp8_decode_tokens,
kv_cache,
weights[:num_padded_tokens],
@@ -749,7 +762,8 @@ def sparse_attn_indexer_fake(
_flattened_kv = torch.empty(
[total_seq_lens, head_dim + 4], device=k.device, dtype=torch.uint8
)
_k_fp8 = _flattened_kv[..., :head_dim].view(torch.float8_e4m3fn).contiguous()
fp8_dtype = current_platform.fp8_dtype()
_k_fp8 = _flattened_kv[..., :head_dim].view(fp8_dtype).contiguous()
_k_scale = _flattened_kv[..., head_dim:].view(torch.float32).contiguous()
return topk_indices_buffer