Add XPU MLA Sparse backend for DeepSeek v3.2 (#33230)

Signed-off-by: Zhang, Wuxun <wuxun.zhang@intel.com>
This commit is contained in:
Wuxun Zhang
2026-03-11 19:19:15 +08:00
committed by GitHub
parent 40c0461f24
commit e584dce52b
9 changed files with 940 additions and 24 deletions

View File

@@ -135,16 +135,29 @@ def sparse_attn_indexer(
topk_indices = topk_indices_buffer[
chunk.token_start : chunk.token_end, :topk_tokens
]
torch.ops._C.top_k_per_row_prefill(
logits,
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
topk_indices,
num_rows,
logits.stride(0),
logits.stride(1),
topk_tokens,
)
if current_platform.is_xpu():
ops.top_k_per_row_prefill(
logits,
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
topk_indices,
num_rows,
logits.stride(0),
logits.stride(1),
topk_tokens,
)
else:
torch.ops._C.top_k_per_row_prefill(
logits,
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
topk_indices,
num_rows,
logits.stride(0),
logits.stride(1),
topk_tokens,
)
# Compute lengths from row spans
# lengths = (chunk.cu_seqlen_ke - chunk.cu_seqlen_ks).to(torch.int32)
@@ -220,16 +233,28 @@ def sparse_attn_indexer(
None,
)
else:
torch.ops._C.top_k_per_row_decode(
logits,
next_n,
decode_metadata.seq_lens,
topk_indices,
num_rows,
logits.stride(0),
logits.stride(1),
topk_tokens,
)
if current_platform.is_xpu():
ops.top_k_per_row_decode(
logits,
next_n,
decode_metadata.seq_lens,
topk_indices,
num_rows,
logits.stride(0),
logits.stride(1),
topk_tokens,
)
else:
torch.ops._C.top_k_per_row_decode(
logits,
next_n,
decode_metadata.seq_lens,
topk_indices,
num_rows,
logits.stride(0),
logits.stride(1),
topk_tokens,
)
if decode_metadata.requires_padding:
# if padded, we need to unpack
@@ -320,14 +345,14 @@ class SparseAttnIndexer(CustomOp):
k: torch.Tensor,
weights: torch.Tensor,
):
if current_platform.is_cuda():
if current_platform.is_cuda() or current_platform.is_xpu():
return self.forward_cuda(hidden_states, q_fp8, k, weights)
elif current_platform.is_rocm():
return self.forward_hip(hidden_states, q_fp8, k, weights)
else:
raise NotImplementedError(
"SparseAttnIndexer native forward is only implemented for "
"CUDA and ROCm platform."
"CUDA, ROCm and XPU platforms."
)
def forward_cuda(