[Deepseek v3.2] Optimize top_k_per_row (#26763)
Signed-off-by: Daniel Campora <961215+dcampora@users.noreply.github.com>
This commit is contained in:
@@ -39,10 +39,9 @@ def create_row_boundaries(
|
||||
|
||||
|
||||
def compare_top_k_results(
|
||||
logits: torch.Tensor,
|
||||
cuda_indices: torch.Tensor,
|
||||
cuda_values: torch.Tensor,
|
||||
torch_indices: torch.Tensor,
|
||||
torch_values: torch.Tensor,
|
||||
row_starts: torch.Tensor,
|
||||
row_ends: torch.Tensor,
|
||||
top_k: int,
|
||||
@@ -70,8 +69,9 @@ def compare_top_k_results(
|
||||
continue
|
||||
|
||||
# Any difference in elements, compare the values
|
||||
cuda_row_values = cuda_values[row_idx][:num_valid].cpu()
|
||||
torch_row_values = torch_values[row_idx][:num_valid].cpu()
|
||||
logits_row = logits[row_idx]
|
||||
cuda_row_values = [logits_row[i] for i in cuda_row_indices]
|
||||
torch_row_values = [logits_row[i] for i in torch_row_indices]
|
||||
|
||||
cuda_only_values, torch_only_values = [], []
|
||||
for idx in cuda_set - torch_set:
|
||||
@@ -115,7 +115,6 @@ def test_top_k_per_row(
|
||||
|
||||
# Create output tensors
|
||||
indices = torch.empty((num_rows, 2048), dtype=torch.int32, device="cuda")
|
||||
values = torch.empty((num_rows, 2048), dtype=torch.float32, device="cuda")
|
||||
|
||||
# Run CUDA implementation
|
||||
torch.ops._C.top_k_per_row(
|
||||
@@ -123,14 +122,13 @@ def test_top_k_per_row(
|
||||
row_starts,
|
||||
row_ends,
|
||||
indices,
|
||||
values,
|
||||
num_rows,
|
||||
logits.stride(0),
|
||||
logits.stride(1),
|
||||
)
|
||||
|
||||
# Run reference implementation
|
||||
torch_values, torch_indices = logits.topk(min(top_k, max(row_ends)), dim=-1)
|
||||
torch_indices = logits.topk(min(top_k, max(row_ends)), dim=-1)[1]
|
||||
mask_lo = torch_indices >= 0
|
||||
mask_hi = (torch_indices - (row_ends - row_starts)[:, None]) < 0
|
||||
mask = mask_lo & mask_hi
|
||||
@@ -138,5 +136,5 @@ def test_top_k_per_row(
|
||||
|
||||
# Compare results
|
||||
assert compare_top_k_results(
|
||||
indices, values, torch_indices, torch_values, row_starts, row_ends, top_k
|
||||
logits, indices, torch_indices, row_starts, row_ends, top_k
|
||||
), "CUDA top_k_per_row results don't match torch.topk"
|
||||
|
||||
Reference in New Issue
Block a user