[Perf][Kernel] Add faster topKperRow decode kernel for DeepSeek-V3.2 sparse attention (#33680)
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: Roberto L. Castro <38211239+LopezCastroRoberto@users.noreply.github.com>
Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
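For context, a minimal PyTorch sketch of the semantics the decode kernel is assumed to implement: each of the `batch_size * next_n` rows takes a top-k over the valid prefix of its index-score logits. The helper name `ref_top_k_per_row_decode`, the `offsets = arange(next_n)` assumption, and the `-1` sentinel for out-of-range slots are illustrative only, not taken from this PR.

```python
import torch


def ref_top_k_per_row_decode(
    logits: torch.Tensor,    # (num_rows, max_seq_len) index scores, num_rows = batch * next_n
    seq_lens: torch.Tensor,  # (batch,) valid KV length per request
    next_n: int,             # speculative tokens decoded per request
    topk_tokens: int,
) -> torch.Tensor:
    assert logits.shape[0] == seq_lens.numel() * next_n
    # Row j of a request attends to seq_len - next_n + 1 + j tokens
    # (assuming offsets are simply 0..next_n-1, as the shape comment in the diff suggests).
    offsets = torch.arange(next_n, device=logits.device)
    lengths = (seq_lens.unsqueeze(1) - next_n + 1 + offsets).flatten()  # (num_rows,)

    # Mask positions beyond each row's valid length, then take a per-row top-k.
    positions = torch.arange(logits.shape[1], device=logits.device)
    masked = logits.masked_fill(positions.unsqueeze(0) >= lengths.unsqueeze(1), float("-inf"))
    topk_indices = masked.topk(topk_tokens, dim=-1).indices  # (num_rows, topk_tokens)

    # Rows shorter than topk_tokens keep only `length` valid entries; mark the rest
    # with -1 (an assumed sentinel, not necessarily what the CUDA kernel writes).
    invalid = torch.arange(topk_tokens, device=logits.device).unsqueeze(0) >= lengths.unsqueeze(1)
    return topk_indices.masked_fill(invalid, -1)
```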
@@ -126,6 +126,15 @@ def sparse_attn_indexer(
                 topk_tokens,
             )
 
+            # Compute lengths from row spans
+            # lengths = (chunk.cu_seqlen_ke - chunk.cu_seqlen_ks).to(torch.int32)
+            # torch.ops._C.large_context_topk(
+            #     logits,
+            #     topk_indices,
+            #     lengths,
+            #     chunk.cu_seqlen_ks,  # row_starts
+            # )
+
     if has_decode:
         decode_metadata = attn_metadata.decode
         # kv_cache size requirement [num_block, block_size, n_head, head_dim],
@@ -162,18 +171,37 @@ def sparse_attn_indexer(
         )
 
         num_rows = logits.shape[0]
 
         topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens]
-        torch.ops._C.top_k_per_row_decode(
-            logits,
-            next_n,
-            decode_metadata.seq_lens,
-            topk_indices,
-            num_rows,
-            logits.stride(0),
-            logits.stride(1),
-            topk_tokens,
-        )
+        if decode_metadata.use_large_context_topk:
+            if next_n == 1:
+                lengths = decode_metadata.seq_lens
+            else:
+                # (bs,) -> (bs, 1) + (next_n,) -> (bs, next_n) -> (bs * next_n,)
+                lengths = (
+                    decode_metadata.seq_lens.unsqueeze(1)
+                    - next_n
+                    + 1
+                    + decode_metadata.offsets
+                ).flatten()
+
+            torch.ops._C.large_context_topk(
+                logits,
+                topk_indices,
+                lengths,
+                None,
+            )
+        else:
+            torch.ops._C.top_k_per_row_decode(
+                logits,
+                next_n,
+                decode_metadata.seq_lens,
+                topk_indices,
+                num_rows,
+                logits.stride(0),
+                logits.stride(1),
+                topk_tokens,
+            )
 
         if decode_metadata.requires_padding:
             # if padded, we need to unpack
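The `lengths` computation for `next_n > 1` relies on broadcasting; a small worked example, assuming `decode_metadata.offsets` is simply `arange(next_n)` as the shape comment suggests:

```python
import torch

seq_lens = torch.tensor([10, 7])  # two requests
next_n = 2                        # two speculative positions per request
offsets = torch.arange(next_n)    # assumed [0, 1]; the PR reads decode_metadata.offsets

# (bs,) -> (bs, 1) + (next_n,) -> (bs, next_n) -> (bs * next_n,)
lengths = (seq_lens.unsqueeze(1) - next_n + 1 + offsets).flatten()
print(lengths)  # tensor([ 9, 10,  6,  7])
```

When `decode_metadata.use_large_context_topk` is set, these per-row lengths are passed to `large_context_topk` (with the `row_starts` argument left as `None` in the decode path); otherwise the path falls back to the existing `top_k_per_row_decode` kernel unchanged.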