[Perf][Kernel] Add faster topKperRow decode kernel for DeepSeek-V3.2 sparse attention (#33680)

Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: Roberto L. Castro <38211239+LopezCastroRoberto@users.noreply.github.com>
Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
Roberto L. Castro
2026-02-10 16:29:52 +01:00
committed by GitHub
parent 82e11973cc
commit afdce12c89
8 changed files with 554 additions and 12 deletions

View File

@@ -126,6 +126,15 @@ def sparse_attn_indexer(
topk_tokens,
)
# Compute lengths from row spans
# lengths = (chunk.cu_seqlen_ke - chunk.cu_seqlen_ks).to(torch.int32)
# torch.ops._C.large_context_topk(
# logits,
# topk_indices,
# lengths,
# chunk.cu_seqlen_ks, # row_starts
# )
if has_decode:
decode_metadata = attn_metadata.decode
# kv_cache size requirement [num_block, block_size, n_head, head_dim],
@@ -162,18 +171,37 @@ def sparse_attn_indexer(
)
num_rows = logits.shape[0]
topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens]
torch.ops._C.top_k_per_row_decode(
logits,
next_n,
decode_metadata.seq_lens,
topk_indices,
num_rows,
logits.stride(0),
logits.stride(1),
topk_tokens,
)
if decode_metadata.use_large_context_topk:
if next_n == 1:
lengths = decode_metadata.seq_lens
else:
# (bs,) -> (bs, 1) + (next_n,) -> (bs, next_n) -> (bs * next_n,)
lengths = (
decode_metadata.seq_lens.unsqueeze(1)
- next_n
+ 1
+ decode_metadata.offsets
).flatten()
torch.ops._C.large_context_topk(
logits,
topk_indices,
lengths,
None,
)
else:
torch.ops._C.top_k_per_row_decode(
logits,
next_n,
decode_metadata.seq_lens,
topk_indices,
num_rows,
logits.stride(0),
logits.stride(1),
topk_tokens,
)
if decode_metadata.requires_padding:
# if padded, we need to unpack