Add XPU MLA Sparse backend for DeepSeek v3.2 (#33230)
Signed-off-by: Zhang, Wuxun <wuxun.zhang@intel.com>
This commit is contained in:
@@ -135,16 +135,29 @@ def sparse_attn_indexer(
|
||||
topk_indices = topk_indices_buffer[
|
||||
chunk.token_start : chunk.token_end, :topk_tokens
|
||||
]
|
||||
- torch.ops._C.top_k_per_row_prefill(
|
||||
- logits,
|
||||
- chunk.cu_seqlen_ks,
|
||||
- chunk.cu_seqlen_ke,
|
||||
- topk_indices,
|
||||
- num_rows,
|
||||
- logits.stride(0),
|
||||
- logits.stride(1),
|
||||
- topk_tokens,
|
||||
- )
|
||||
|
||||
+ if current_platform.is_xpu():
|
||||
+ ops.top_k_per_row_prefill(
|
||||
+ logits,
|
||||
+ chunk.cu_seqlen_ks,
|
||||
+ chunk.cu_seqlen_ke,
|
||||
+ topk_indices,
|
||||
+ num_rows,
|
||||
+ logits.stride(0),
|
||||
+ logits.stride(1),
|
||||
+ topk_tokens,
|
||||
+ )
|
||||
+ else:
|
||||
+ torch.ops._C.top_k_per_row_prefill(
|
||||
+ logits,
|
||||
+ chunk.cu_seqlen_ks,
|
||||
+ chunk.cu_seqlen_ke,
|
||||
+ topk_indices,
|
||||
+ num_rows,
|
||||
+ logits.stride(0),
|
||||
+ logits.stride(1),
|
||||
+ topk_tokens,
|
||||
+ )
|
||||
|
||||
# Compute lengths from row spans
|
||||
# lengths = (chunk.cu_seqlen_ke - chunk.cu_seqlen_ks).to(torch.int32)
|
||||
@@ -220,16 +233,28 @@ def sparse_attn_indexer(
|
||||
None,
|
||||
)
|
||||
else:
|
||||
- torch.ops._C.top_k_per_row_decode(
|
||||
- logits,
|
||||
- next_n,
|
||||
- decode_metadata.seq_lens,
|
||||
- topk_indices,
|
||||
- num_rows,
|
||||
- logits.stride(0),
|
||||
- logits.stride(1),
|
||||
- topk_tokens,
|
||||
- )
|
||||
+ if current_platform.is_xpu():
|
||||
+ ops.top_k_per_row_decode(
|
||||
+ logits,
|
||||
+ next_n,
|
||||
+ decode_metadata.seq_lens,
|
||||
+ topk_indices,
|
||||
+ num_rows,
|
||||
+ logits.stride(0),
|
||||
+ logits.stride(1),
|
||||
+ topk_tokens,
|
||||
+ )
|
||||
+ else:
|
||||
+ torch.ops._C.top_k_per_row_decode(
|
||||
+ logits,
|
||||
+ next_n,
|
||||
+ decode_metadata.seq_lens,
|
||||
+ topk_indices,
|
||||
+ num_rows,
|
||||
+ logits.stride(0),
|
||||
+ logits.stride(1),
|
||||
+ topk_tokens,
|
||||
+ )
|
||||
|
||||
if decode_metadata.requires_padding:
|
||||
# if padded, we need to unpack
|
||||
@@ -320,14 +345,14 @@ class SparseAttnIndexer(CustomOp):
|
||||
k: torch.Tensor,
|
||||
weights: torch.Tensor,
|
||||
):
|
||||
- if current_platform.is_cuda():
|
||||
+ if current_platform.is_cuda() or current_platform.is_xpu():
|
||||
return self.forward_cuda(hidden_states, q_fp8, k, weights)
|
||||
elif current_platform.is_rocm():
|
||||
return self.forward_hip(hidden_states, q_fp8, k, weights)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"SparseAttnIndexer native forward is only implemented for "
|
||||
- "CUDA and ROCm platform."
|
||||
+ "CUDA, ROCm and XPU platforms."
|
||||
)
|
||||
|
||||
def forward_cuda(
|
||||
|
||||
Reference in New Issue
Block a user