[Perf][Kernel] Add faster topKperRow decode kernel for DeepSeek-V3.2 sparse attention (#33680)

Signed-off-by: LopezCastroRoberto <rocastro@redhat.com>
Signed-off-by: Roberto L. Castro <38211239+LopezCastroRoberto@users.noreply.github.com>
Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
Roberto L. Castro
2026-02-10 16:29:52 +01:00
committed by GitHub
parent 82e11973cc
commit afdce12c89
8 changed files with 554 additions and 12 deletions

View File

@@ -275,3 +275,114 @@ def test_top_k_per_row_decode_large_vocab_size(clean_logits: bool) -> None:
_run_top_k_per_row_decode_test(
top_k, batch_size, next_n, vocab_size, clean_logits, data_generation
)
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
@pytest.mark.parametrize("clean_logits", [True, False])
@torch.inference_mode()
def test_deepseek_hybrid_topk(clean_logits: bool) -> None:
    """Validate both top-k decode paths against a torch.topk reference.

    Short sequences (< 8192) go through ``top_k_per_row_decode``; long
    sequences (>= 8192) go through ``large_context_topk``. Each kernel's
    output indices are compared against a per-row ``torch.topk`` reference
    via ``compare_top_k_results``.
    """
    torch.set_default_device("cuda:0")
    top_k = 2048

    def reference_topk(
        logits: torch.Tensor, row_ends: torch.Tensor, n_rows: int
    ) -> torch.Tensor:
        # Row-by-row torch.topk reference. Rows shorter than top_k only
        # fill a prefix (never hit here: all seq lens exceed top_k).
        out = torch.empty((n_rows, top_k), dtype=torch.int32, device="cuda")
        for row in range(n_rows):
            end = int(row_ends[row])
            k = min(top_k, end)
            out[row, :k] = logits[row, :end].topk(k, dim=-1)[1]
        return out

    # --- Case 1: short sequences (< 8192) exercise top_k_per_row_decode ---
    short_batch = 4
    next_n = 1
    n_rows_short = short_batch * next_n
    # All lengths stay below the 8192 routing threshold.
    short_seq_lens = torch.randint(
        4000, 8000, (short_batch,), dtype=torch.int32, device="cuda"
    )
    short_row_starts = torch.zeros(n_rows_short, dtype=torch.int32, device="cuda")
    batch_idx_short = torch.arange(n_rows_short, device="cuda") // next_n
    draft_off_short = torch.arange(n_rows_short, device="cuda") % next_n
    short_row_ends = (
        short_seq_lens[batch_idx_short] - next_n + draft_off_short + 1
    )
    short_logits = create_random_logits(
        short_row_starts, short_row_ends, torch.float32, 42, clean_logits, "random"
    )
    short_out = torch.empty(
        (n_rows_short, top_k), dtype=torch.int32, device="cuda"
    )
    # vLLM's short-sequence kernel.
    torch.ops._C.top_k_per_row_decode(
        short_logits,
        next_n,
        short_seq_lens,
        short_out,
        n_rows_short,
        short_logits.stride(0),
        short_logits.stride(1),
        top_k,
    )

    # --- Case 2: long sequences (>= 8192) route to large_context_topk ---
    long_batch = 4
    n_rows_long = long_batch * next_n
    # All lengths at or above the 8192 routing threshold.
    long_seq_lens = torch.randint(
        8192, 16384, (long_batch,), dtype=torch.int32, device="cuda"
    )
    long_row_starts = torch.zeros(n_rows_long, dtype=torch.int32, device="cuda")
    batch_idx_long = torch.arange(n_rows_long, device="cuda") // next_n
    draft_off_long = torch.arange(n_rows_long, device="cuda") % next_n
    long_row_ends = long_seq_lens[batch_idx_long] - next_n + draft_off_long + 1
    long_logits = create_random_logits(
        long_row_starts, long_row_ends, torch.float32, 43, clean_logits, "random"
    )
    long_out = torch.empty((n_rows_long, top_k), dtype=torch.int32, device="cuda")
    # Per-row lengths: flattened per-draft-token lengths when next_n > 1,
    # trivially the sequence lengths here (next_n == 1).
    if next_n == 1:
        lengths = long_seq_lens
    else:
        offsets = torch.arange(next_n, device=long_logits.device, dtype=torch.int32)
        lengths = (long_seq_lens.unsqueeze(1) - next_n + 1 + offsets).flatten()
    torch.ops._C.large_context_topk(
        long_logits,
        long_out,
        lengths,
        None,
    )

    assert compare_top_k_results(
        short_logits,
        short_out,
        reference_topk(short_logits, short_row_ends, n_rows_short),
        short_row_starts,
        short_row_ends,
        top_k,
    ), "top_k_per_row_decode kernel (short sequences) doesn't match torch.topk"
    assert compare_top_k_results(
        long_logits,
        long_out,
        reference_topk(long_logits, long_row_ends, n_rows_long),
        long_row_starts,
        long_row_ends,
        top_k,
    ), "large_context_topk kernel (long sequences) doesn't match torch.topk"