[Perf][Kernel] Add faster topKperRow decode kernel for DeepSeek-V3.2 sparse attention (#33680)
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com> Signed-off-by: Roberto L. Castro <38211239+LopezCastroRoberto@users.noreply.github.com> Co-authored-by: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
committed by
GitHub
parent
82e11973cc
commit
afdce12c89
@@ -275,3 +275,114 @@ def test_top_k_per_row_decode_large_vocab_size(clean_logits: bool) -> None:
|
||||
_run_top_k_per_row_decode_test(
|
||||
top_k, batch_size, next_n, vocab_size, clean_logits, data_generation
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not current_platform.is_cuda(), reason="This test requires CUDA")
|
||||
@pytest.mark.parametrize("clean_logits", [True, False])
|
||||
@torch.inference_mode()
|
||||
def test_deepseek_hybrid_topk(clean_logits: bool) -> None:
|
||||
torch.set_default_device("cuda:0")
|
||||
|
||||
top_k = 2048
|
||||
|
||||
# Test case 1: Short sequences (< 8192)
|
||||
batch_size_short = 4
|
||||
next_n = 1
|
||||
num_rows_short = batch_size_short * next_n
|
||||
|
||||
# Create sequences with max length < 8192
|
||||
seq_lens_short = torch.randint(
|
||||
4000, 8000, (batch_size_short,), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
|
||||
row_starts_short = torch.zeros(num_rows_short, dtype=torch.int32, device="cuda")
|
||||
row_indices_short = torch.arange(num_rows_short, device="cuda") // next_n
|
||||
next_n_offset_short = torch.arange(num_rows_short, device="cuda") % next_n
|
||||
row_ends_short = (
|
||||
seq_lens_short[row_indices_short] - next_n + next_n_offset_short + 1
|
||||
)
|
||||
|
||||
logits_short = create_random_logits(
|
||||
row_starts_short, row_ends_short, torch.float32, 42, clean_logits, "random"
|
||||
)
|
||||
|
||||
indices_vllm = torch.empty(
|
||||
(num_rows_short, top_k), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
|
||||
# Use vllm's kernel for short sequences
|
||||
torch.ops._C.top_k_per_row_decode(
|
||||
logits_short,
|
||||
next_n,
|
||||
seq_lens_short,
|
||||
indices_vllm,
|
||||
num_rows_short,
|
||||
logits_short.stride(0),
|
||||
logits_short.stride(1),
|
||||
top_k,
|
||||
)
|
||||
|
||||
# Test case 2: Long sequences (>= 8192) - should use large_context_topk kernel
|
||||
batch_size_long = 4
|
||||
num_rows_long = batch_size_long * next_n
|
||||
|
||||
# Create sequences with max length >= 8192
|
||||
seq_lens_long = torch.randint(
|
||||
8192, 16384, (batch_size_long,), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
|
||||
row_starts_long = torch.zeros(num_rows_long, dtype=torch.int32, device="cuda")
|
||||
row_indices_long = torch.arange(num_rows_long, device="cuda") // next_n
|
||||
next_n_offset_long = torch.arange(num_rows_long, device="cuda") % next_n
|
||||
row_ends_long = seq_lens_long[row_indices_long] - next_n + next_n_offset_long + 1
|
||||
|
||||
logits_long = create_random_logits(
|
||||
row_starts_long, row_ends_long, torch.float32, 43, clean_logits, "random"
|
||||
)
|
||||
|
||||
indices = torch.empty((num_rows_long, top_k), dtype=torch.int32, device="cuda")
|
||||
|
||||
# Use large_context_topk kernel for long sequences
|
||||
if next_n == 1:
|
||||
lengths = seq_lens_long
|
||||
else:
|
||||
offsets = torch.arange(next_n, device=logits_long.device, dtype=torch.int32)
|
||||
lengths = (seq_lens_long.unsqueeze(1) - next_n + 1 + offsets).flatten()
|
||||
|
||||
torch.ops._C.large_context_topk(
|
||||
logits_long,
|
||||
indices,
|
||||
lengths,
|
||||
None,
|
||||
)
|
||||
|
||||
torch_indices_short = torch.empty(
|
||||
(num_rows_short, top_k), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
for i in range(num_rows_short):
|
||||
row_end = int(row_ends_short[i])
|
||||
k_i = min(top_k, row_end)
|
||||
idx = logits_short[i, :row_end].topk(k_i, dim=-1)[1]
|
||||
torch_indices_short[i, :k_i] = idx
|
||||
|
||||
assert compare_top_k_results(
|
||||
logits_short,
|
||||
indices_vllm,
|
||||
torch_indices_short,
|
||||
row_starts_short,
|
||||
row_ends_short,
|
||||
top_k,
|
||||
), "top_k_per_row_decode kernel (short sequences) doesn't match torch.topk"
|
||||
|
||||
torch_indices_long = torch.empty(
|
||||
(num_rows_long, top_k), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
for i in range(num_rows_long):
|
||||
row_end = int(row_ends_long[i])
|
||||
k_i = min(top_k, row_end)
|
||||
idx = logits_long[i, :row_end].topk(k_i, dim=-1)[1]
|
||||
torch_indices_long[i, :k_i] = idx
|
||||
|
||||
assert compare_top_k_results(
|
||||
logits_long, indices, torch_indices_long, row_starts_long, row_ends_long, top_k
|
||||
), "large_context_topk kernel (long sequences) doesn't match torch.topk"
|
||||
|
||||
Reference in New Issue
Block a user