[Perf] Optimize top-k search in apply_top_k_top_p_triton sampler (#37225)

Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
Michael Goin
2026-03-17 19:35:17 +01:00
committed by GitHub
parent 4ed51308c8
commit 51b2333be1

View File

@@ -67,6 +67,29 @@ _PERCENTILE_TO_STD_TABLE = [
# fmt: on
@triton.jit
def _update_min_larger_stats(data, above_mask, min_larger, num_min_larger, sentinel):
"""Update running (min, count) of values above a pivot across tiles.
Tracks the smallest value strictly above a pivot and how many times
it occurs. Called once per tile per pivot; the running state is
carried across tiles via `min_larger` / `num_min_larger`.
Merge rule:
- tile min < running min → replace both
- tile min == running min → accumulate count
- tile min > running min → keep running values
"""
tile_min = tl.min(tl.where(above_mask, data, sentinel))
tile_eq = above_mask & (tl.abs(data - tile_min) < 1e-9)
tile_cnt = tl.sum(tile_eq)
is_new = tile_min < min_larger
is_same = tl.abs(tile_min - min_larger) < 1e-9
num_min_larger = tl.where(is_new, tile_cnt, num_min_larger + tile_cnt * is_same)
min_larger = tl.minimum(min_larger, tile_min)
return min_larger, num_min_larger
@triton.jit
def _topk_topp_kernel(
LOGITS,
@@ -188,7 +211,10 @@ def _topk_topp_kernel(
min_larger_1 = float("inf")
num_min_larger_1 = tl.zeros((), dtype=tl.uint32)
# First pass: Calculate k_pivots_num and min_larger
# Single fused pass: compute k_pivots_num,
# min_larger, and num_min_larger together to avoid
# a second data scan. See _update_min_larger_stats
# for the tile-level merge logic.
for i in range(0, search_iters):
offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(
0, BLOCK_SIZE_TRUNC
@@ -198,27 +224,24 @@ def _topk_topp_kernel(
BUFFER_ROW + offs_n, mask=mask_n_2, other=-float("inf")
)
k_pivots_num_0 += tl.sum(logits_blk2 > k_pivot_0)
k_pivots_num_1 += tl.sum(logits_blk2 > k_pivot_1)
above_0 = logits_blk2 > k_pivot_0
above_1 = logits_blk2 > k_pivot_1
k_pivots_num_0 += tl.sum(above_0)
k_pivots_num_1 += tl.sum(above_1)
min_larger_0 = tl.minimum(min_larger_0, tl.min(logits_blk2))
min_larger_1 = tl.minimum(min_larger_1, tl.min(logits_blk2))
# Second pass: Calculate num_min_larger
for i in range(0, search_iters):
offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(
0, BLOCK_SIZE_TRUNC
min_larger_0, num_min_larger_0 = _update_min_larger_stats(
logits_blk2,
above_0,
min_larger_0,
num_min_larger_0,
float("inf"),
)
mask_n_2 = offs_n < search_range
logits_blk2 = tl.load(
BUFFER_ROW + offs_n, mask=mask_n_2, other=-float("inf")
)
num_min_larger_0 += tl.sum(
tl.abs(logits_blk2 - min_larger_0) < 1e-9
)
num_min_larger_1 += tl.sum(
tl.abs(logits_blk2 - min_larger_1) < 1e-9
min_larger_1, num_min_larger_1 = _update_min_larger_stats(
logits_blk2,
above_1,
min_larger_1,
num_min_larger_1,
float("inf"),
)
# Check if any of the pivots satisfy termination condition
@@ -272,7 +295,8 @@ def _topk_topp_kernel(
min_larger_1 = float("inf")
num_min_larger_1 = tl.zeros((), dtype=tl.uint32)
# First pass: Calculate k_pivots_num and min_larger
# Single fused pass over full vocab (same approach
# as the buffer path above).
for i in range(0, NUM_TILES):
offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask_n = offs_n < VOCAB_SIZE
@@ -280,30 +304,24 @@ def _topk_topp_kernel(
LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf")
)
k_pivots_num_0 += tl.sum(logits_blk2 > k_pivot_0)
k_pivots_num_1 += tl.sum(logits_blk2 > k_pivot_1)
above_0 = logits_blk2 > k_pivot_0
above_1 = logits_blk2 > k_pivot_1
k_pivots_num_0 += tl.sum(above_0)
k_pivots_num_1 += tl.sum(above_1)
# Exclude -inf from min_larger to avoid
# poisoning the convergence check.
finite_blk2 = tl.where(
logits_blk2 > -float("inf"), logits_blk2, float("inf")
min_larger_0, num_min_larger_0 = _update_min_larger_stats(
logits_blk2,
above_0,
min_larger_0,
num_min_larger_0,
float("inf"),
)
min_larger_0 = tl.minimum(min_larger_0, tl.min(finite_blk2))
min_larger_1 = tl.minimum(min_larger_1, tl.min(finite_blk2))
# Second pass: Calculate num_min_larger
for i in range(0, NUM_TILES):
offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask_n = offs_n < VOCAB_SIZE
logits_blk2 = tl.load(
LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf")
)
num_min_larger_0 += tl.sum(
tl.abs(logits_blk2 - min_larger_0) < 1e-9
)
num_min_larger_1 += tl.sum(
tl.abs(logits_blk2 - min_larger_1) < 1e-9
min_larger_1, num_min_larger_1 = _update_min_larger_stats(
logits_blk2,
above_1,
min_larger_1,
num_min_larger_1,
float("inf"),
)
# Check if any of the pivots satisfy termination condition