[Perf] Optimize top-k search in apply_top_k_top_p_triton sampler (#37225)
Signed-off-by: mgoin <mgoin64@gmail.com>
This commit is contained in:
@@ -67,6 +67,29 @@ _PERCENTILE_TO_STD_TABLE = [
|
||||
# fmt: on
|
||||
|
||||
|
||||
@triton.jit
def _update_min_larger_stats(data, above_mask, min_larger, num_min_larger, sentinel):
    """Fold one tile into the running (minimum, multiplicity) pair for a pivot.

    Among the lanes of `data` selected by `above_mask`, find the smallest
    value and how many selected lanes match it, then merge that result with
    the state carried over from previously processed tiles.

    Args:
        data: Tile of values to scan.
        above_mask: Boolean mask over `data` marking the lanes that take
            part (callers pass `data > pivot`).
        min_larger: Smallest selected value seen in earlier tiles.
        num_min_larger: Number of occurrences recorded for `min_larger`.
        sentinel: Value substituted for unselected lanes before the min
            reduction (callers pass +inf, so those lanes cannot win).

    Returns:
        The updated `(min_larger, num_min_larger)` pair.

    Merge cases (evaluated in this order):
      1. tile minimum strictly below the running minimum: the count is
         replaced by the tile's count.
      2. tile minimum within 1e-9 (absolute) of the running minimum: the
         tile's count is added to the running count.
      3. otherwise: the running count is kept unchanged.
    In every case the returned minimum is the smaller of the two minima.
    Matching inside the tile uses the same 1e-9 absolute tolerance.
    """
    # Unselected lanes are overwritten with the sentinel before reducing.
    candidates = tl.where(above_mask, data, sentinel)
    tile_min = tl.min(candidates)

    # Count only selected lanes that match the tile minimum.
    matches_tile_min = above_mask & (tl.abs(data - tile_min) < 1e-9)
    tile_cnt = tl.sum(matches_tile_min)

    # Both predicates compare against the incoming running minimum, so they
    # must be computed before `min_larger` is overwritten below.
    replaces_running = tile_min < min_larger
    ties_running = tl.abs(tile_min - min_larger) < 1e-9

    # Multiplying by the tie flag adds the tile count only on a tie.
    merged_cnt = num_min_larger + tile_cnt * ties_running
    num_min_larger = tl.where(replaces_running, tile_cnt, merged_cnt)
    min_larger = tl.minimum(min_larger, tile_min)
    return min_larger, num_min_larger
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _topk_topp_kernel(
|
||||
LOGITS,
|
||||
@@ -188,7 +211,10 @@ def _topk_topp_kernel(
|
||||
min_larger_1 = float("inf")
|
||||
num_min_larger_1 = tl.zeros((), dtype=tl.uint32)
|
||||
|
||||
# First pass: Calculate k_pivots_num and min_larger
|
||||
# Single fused pass: compute k_pivots_num,
|
||||
# min_larger, and num_min_larger together to avoid
|
||||
# a second data scan. See _update_min_larger_stats
|
||||
# for the tile-level merge logic.
|
||||
for i in range(0, search_iters):
|
||||
offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(
|
||||
0, BLOCK_SIZE_TRUNC
|
||||
@@ -198,27 +224,24 @@ def _topk_topp_kernel(
|
||||
BUFFER_ROW + offs_n, mask=mask_n_2, other=-float("inf")
|
||||
)
|
||||
|
||||
k_pivots_num_0 += tl.sum(logits_blk2 > k_pivot_0)
|
||||
k_pivots_num_1 += tl.sum(logits_blk2 > k_pivot_1)
|
||||
above_0 = logits_blk2 > k_pivot_0
|
||||
above_1 = logits_blk2 > k_pivot_1
|
||||
k_pivots_num_0 += tl.sum(above_0)
|
||||
k_pivots_num_1 += tl.sum(above_1)
|
||||
|
||||
min_larger_0 = tl.minimum(min_larger_0, tl.min(logits_blk2))
|
||||
min_larger_1 = tl.minimum(min_larger_1, tl.min(logits_blk2))
|
||||
|
||||
# Second pass: Calculate num_min_larger
|
||||
for i in range(0, search_iters):
|
||||
offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(
|
||||
0, BLOCK_SIZE_TRUNC
|
||||
min_larger_0, num_min_larger_0 = _update_min_larger_stats(
|
||||
logits_blk2,
|
||||
above_0,
|
||||
min_larger_0,
|
||||
num_min_larger_0,
|
||||
float("inf"),
|
||||
)
|
||||
mask_n_2 = offs_n < search_range
|
||||
logits_blk2 = tl.load(
|
||||
BUFFER_ROW + offs_n, mask=mask_n_2, other=-float("inf")
|
||||
)
|
||||
|
||||
num_min_larger_0 += tl.sum(
|
||||
tl.abs(logits_blk2 - min_larger_0) < 1e-9
|
||||
)
|
||||
num_min_larger_1 += tl.sum(
|
||||
tl.abs(logits_blk2 - min_larger_1) < 1e-9
|
||||
min_larger_1, num_min_larger_1 = _update_min_larger_stats(
|
||||
logits_blk2,
|
||||
above_1,
|
||||
min_larger_1,
|
||||
num_min_larger_1,
|
||||
float("inf"),
|
||||
)
|
||||
|
||||
# Check if any of the pivots satisfy termination condition
|
||||
@@ -272,7 +295,8 @@ def _topk_topp_kernel(
|
||||
min_larger_1 = float("inf")
|
||||
num_min_larger_1 = tl.zeros((), dtype=tl.uint32)
|
||||
|
||||
# First pass: Calculate k_pivots_num and min_larger
|
||||
# Single fused pass over full vocab (same approach
|
||||
# as the buffer path above).
|
||||
for i in range(0, NUM_TILES):
|
||||
offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
mask_n = offs_n < VOCAB_SIZE
|
||||
@@ -280,30 +304,24 @@ def _topk_topp_kernel(
|
||||
LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf")
|
||||
)
|
||||
|
||||
k_pivots_num_0 += tl.sum(logits_blk2 > k_pivot_0)
|
||||
k_pivots_num_1 += tl.sum(logits_blk2 > k_pivot_1)
|
||||
above_0 = logits_blk2 > k_pivot_0
|
||||
above_1 = logits_blk2 > k_pivot_1
|
||||
k_pivots_num_0 += tl.sum(above_0)
|
||||
k_pivots_num_1 += tl.sum(above_1)
|
||||
|
||||
# Exclude -inf from min_larger to avoid
|
||||
# poisoning the convergence check.
|
||||
finite_blk2 = tl.where(
|
||||
logits_blk2 > -float("inf"), logits_blk2, float("inf")
|
||||
min_larger_0, num_min_larger_0 = _update_min_larger_stats(
|
||||
logits_blk2,
|
||||
above_0,
|
||||
min_larger_0,
|
||||
num_min_larger_0,
|
||||
float("inf"),
|
||||
)
|
||||
min_larger_0 = tl.minimum(min_larger_0, tl.min(finite_blk2))
|
||||
min_larger_1 = tl.minimum(min_larger_1, tl.min(finite_blk2))
|
||||
|
||||
# Second pass: Calculate num_min_larger
|
||||
for i in range(0, NUM_TILES):
|
||||
offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
mask_n = offs_n < VOCAB_SIZE
|
||||
logits_blk2 = tl.load(
|
||||
LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf")
|
||||
)
|
||||
|
||||
num_min_larger_0 += tl.sum(
|
||||
tl.abs(logits_blk2 - min_larger_0) < 1e-9
|
||||
)
|
||||
num_min_larger_1 += tl.sum(
|
||||
tl.abs(logits_blk2 - min_larger_1) < 1e-9
|
||||
min_larger_1, num_min_larger_1 = _update_min_larger_stats(
|
||||
logits_blk2,
|
||||
above_1,
|
||||
min_larger_1,
|
||||
num_min_larger_1,
|
||||
float("inf"),
|
||||
)
|
||||
|
||||
# Check if any of the pivots satisfy termination condition
|
||||
|
||||
Reference in New Issue
Block a user