[Deepseek v3.2] Optimize top_k_per_row (#26763)

Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com>
This commit is contained in:
Daniel Cámpora
2025-10-21 10:30:07 +02:00
committed by GitHub
parent c3a2c6ac5f
commit 80e9452984
5 changed files with 13 additions and 49 deletions

View File

@@ -39,10 +39,9 @@ def create_row_boundaries(
def compare_top_k_results(
logits: torch.Tensor,
cuda_indices: torch.Tensor,
cuda_values: torch.Tensor,
torch_indices: torch.Tensor,
torch_values: torch.Tensor,
row_starts: torch.Tensor,
row_ends: torch.Tensor,
top_k: int,
@@ -70,8 +69,9 @@ def compare_top_k_results(
continue
# Any difference in elements, compare the values
cuda_row_values = cuda_values[row_idx][:num_valid].cpu()
torch_row_values = torch_values[row_idx][:num_valid].cpu()
logits_row = logits[row_idx]
cuda_row_values = [logits_row[i] for i in cuda_row_indices]
torch_row_values = [logits_row[i] for i in torch_row_indices]
cuda_only_values, torch_only_values = [], []
for idx in cuda_set - torch_set:
@@ -115,7 +115,6 @@ def test_top_k_per_row(
# Create output tensors
indices = torch.empty((num_rows, 2048), dtype=torch.int32, device="cuda")
values = torch.empty((num_rows, 2048), dtype=torch.float32, device="cuda")
# Run CUDA implementation
torch.ops._C.top_k_per_row(
@@ -123,14 +122,13 @@ def test_top_k_per_row(
row_starts,
row_ends,
indices,
values,
num_rows,
logits.stride(0),
logits.stride(1),
)
# Run reference implementation
torch_values, torch_indices = logits.topk(min(top_k, max(row_ends)), dim=-1)
torch_indices = logits.topk(min(top_k, max(row_ends)), dim=-1)[1]
mask_lo = torch_indices >= 0
mask_hi = (torch_indices - (row_ends - row_starts)[:, None]) < 0
mask = mask_lo & mask_hi
@@ -138,5 +136,5 @@ def test_top_k_per_row(
# Compare results
assert compare_top_k_results(
indices, values, torch_indices, torch_values, row_starts, row_ends, top_k
logits, indices, torch_indices, row_starts, row_ends, top_k
), "CUDA top_k_per_row results don't match torch.topk"