[Kernel] Triton-based Top-k and Top-p sampler kernels (#33538)

Signed-off-by: js_park <cakeng@naver.com>
Signed-off-by: Jongseok Park <37990712+cakeng@users.noreply.github.com>
Signed-off-by: Sunga Kim <sunga.kim@berkeley.edu>
Signed-off-by: Nick Hill <nickhill123@gmail.com>
Co-authored-by: Sunga Kim <sunga.kim@berkeley.edu>
Co-authored-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
Jongseok Park
2026-02-17 15:14:30 -08:00
committed by GitHub
parent dc5fa77a4e
commit c656ba3b4d
6 changed files with 2002 additions and 32 deletions

View File

@@ -0,0 +1,471 @@
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Benchmark comparing Triton vs PyTorch sort-based top-k/top-p implementations.
Compares:
- apply_top_k_top_p_triton (Triton binary search)
- apply_top_k_top_p (PyTorch sort-based)
Scenarios:
- top_k only (whole batch, partial batch)
- top_p only (whole batch, partial batch)
- mix of top_k and top_p
"""
import argparse
import gc
from dataclasses import dataclass
import torch
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_pytorch
from vllm.v1.sample.ops.topk_topp_triton import (
apply_top_k_top_p_triton,
reset_buffer_cache,
)
@dataclass
class BenchmarkConfig:
    """Configuration for a benchmark run."""

    # Unique scenario identifier, e.g. "topk_whole_b64_v32k"; the part
    # before "_b" is used as the scenario label in the summary table.
    name: str
    # Number of rows (requests) in the logits batch.
    batch_size: int
    # Number of columns (tokens) per row.
    vocab_size: int
    # k and p can be tensors or None
    k_values: torch.Tensor | None  # [batch_size] or None
    p_values: torch.Tensor | None  # [batch_size] or None
    # Human-readable description printed in verbose mode.
    description: str
    ops_pct: float = 0.0  # Percentage of ops relative to batch size
def calculate_ops_pct(
k_values: torch.Tensor | None,
p_values: torch.Tensor | None,
vocab_size: int,
batch_size: int,
) -> float:
"""
Calculate the percentage of active top-k and top-p operations.
Returns percentage where 100% = batch_size ops.
E.g., if all rows have both top-k and top-p active, returns 200%.
"""
active_ops = 0
if k_values is not None:
# Count rows where k < vocab_size (active top-k filtering)
active_ops += (k_values < vocab_size).sum().item()
if p_values is not None:
# Count rows where p < 1.0 (active top-p filtering)
active_ops += (p_values < 1.0).sum().item()
return (active_ops / batch_size) * 100 if batch_size > 0 else 0.0
def create_logits(
    batch_size: int, vocab_size: int, device: str = "cuda"
) -> torch.Tensor:
    """Create random logits mimicking a realistic LLM distribution.

    A Zipf-like probability mass (rank ** -1.1) is converted to logits via
    log and then independently shuffled per row, yielding a peaked
    distribution where a few tokens capture most probability mass,
    similar to real model outputs.
    """
    # Zipf-like probabilities over ranks 1..vocab_size: p(rank) ~ rank^(-alpha).
    rank = torch.arange(1, vocab_size + 1, dtype=torch.float32, device=device)
    weight = rank.pow(-1.1)
    weight = weight / weight.sum()
    # Unnormalized log-probabilities serve as the shared base row.
    base_row = weight.log()
    out = base_row.unsqueeze(0).expand(batch_size, -1).clone()
    # Shuffle each row independently so rows are decorrelated.
    for row in range(batch_size):
        out[row] = out[row, torch.randperm(vocab_size, device=device)]
    return out
def measure_memory() -> tuple[int, int]:
    """Return (currently allocated, peak allocated) CUDA memory in bytes.

    Synchronizes the device first so pending kernels are accounted for.
    NOTE: the second element is ``max_memory_allocated()`` — the peak
    allocation since the last ``reset_peak_memory_stats()`` — not
    reserved memory.
    """
    torch.cuda.synchronize()
    return torch.cuda.memory_allocated(), torch.cuda.max_memory_allocated()
def reset_memory_stats():
    """Reset peak memory statistics.

    Drops the Triton implementation's cached intermediate buffers, resets
    the peak-allocation counter read by ``measure_memory``, releases cached
    CUDA blocks, and runs the Python garbage collector so that subsequent
    peak-memory measurements start from a clean slate.
    """
    reset_buffer_cache()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.empty_cache()
    gc.collect()
def benchmark_function(
    func,
    logits: torch.Tensor,
    k: torch.Tensor | None,
    p: torch.Tensor | None,
    warmup_iters: int = 5,
    benchmark_iters: int = 20,
) -> tuple[float, int]:
    """
    Benchmark ``func(logits, k, p)`` and return (avg_time_ms, peak_memory_bytes).

    Times ``benchmark_iters`` iterations with CUDA events (each on a fresh
    clone of ``logits`` since the implementations mutate in place) and
    reports the average latency plus the peak allocated memory.
    """
    # Warm up so compilation / autotuning does not pollute the timings.
    for _ in range(warmup_iters):
        func(logits.clone(), k, p)
    torch.cuda.synchronize()
    # Start from a clean slate so peak memory reflects only the benchmark.
    reset_memory_stats()
    # One start/end event pair per timed iteration.
    starts = [torch.cuda.Event(enable_timing=True) for _ in range(benchmark_iters)]
    ends = [torch.cuda.Event(enable_timing=True) for _ in range(benchmark_iters)]
    for begin, finish in zip(starts, ends):
        scratch = logits.clone()
        begin.record()
        func(scratch, k, p)
        finish.record()
    torch.cuda.synchronize()
    # Average the per-iteration GPU times (milliseconds).
    elapsed = [begin.elapsed_time(finish) for begin, finish in zip(starts, ends)]
    avg_ms = sum(elapsed) / len(elapsed)
    # Peak allocated bytes during the timed loop.
    _, peak = measure_memory()
    return avg_ms, peak
def create_benchmark_configs(
    batch_sizes: list[int],
    vocab_sizes: list[int],
    device: str = "cuda",
) -> list[BenchmarkConfig]:
    """Create all benchmark configurations.

    For each (vocab_size, batch_size) pair, six scenarios are generated:
    top-k only (whole/partial batch), top-p only (whole/partial batch),
    top-k + top-p on the whole batch, and a mixed-partial batch where
    thirds of the rows use k only, p only, or both.
    """
    configs = []
    for vocab_size in vocab_sizes:
        for batch_size in batch_sizes:
            # 1. Top-k only - whole batch (all rows have k < vocab_size)
            k_all = torch.full((batch_size,), 50, dtype=torch.int32, device=device)
            configs.append(
                BenchmarkConfig(
                    name=f"topk_whole_b{batch_size}_v{vocab_size // 1000}k",
                    batch_size=batch_size,
                    vocab_size=vocab_size,
                    k_values=k_all,
                    p_values=None,
                    description=f"Top-k only (whole batch, k=50), "
                    f"batch={batch_size}, vocab={vocab_size}",
                    ops_pct=calculate_ops_pct(k_all, None, vocab_size, batch_size),
                )
            )
            # 2. Top-k only - partial batch (half have k=50, half have k=vocab_size)
            k_partial = torch.full((batch_size,), 50, dtype=torch.int32, device=device)
            k_partial[batch_size // 2 :] = vocab_size  # No filtering for second half
            configs.append(
                BenchmarkConfig(
                    name=f"topk_partial_b{batch_size}_v{vocab_size // 1000}k",
                    batch_size=batch_size,
                    vocab_size=vocab_size,
                    k_values=k_partial,
                    p_values=None,
                    description=f"Top-k only (partial batch, 50% k=50, 50% k=vocab), "
                    f"batch={batch_size}, vocab={vocab_size}",
                    ops_pct=calculate_ops_pct(k_partial, None, vocab_size, batch_size),
                )
            )
            # 3. Top-p only - whole batch (all rows have p < 1.0)
            p_all = torch.full((batch_size,), 0.9, dtype=torch.float32, device=device)
            configs.append(
                BenchmarkConfig(
                    name=f"topp_whole_b{batch_size}_v{vocab_size // 1000}k",
                    batch_size=batch_size,
                    vocab_size=vocab_size,
                    k_values=None,
                    p_values=p_all,
                    description=f"Top-p only (whole batch, p=0.9), "
                    f"batch={batch_size}, vocab={vocab_size}",
                    ops_pct=calculate_ops_pct(None, p_all, vocab_size, batch_size),
                )
            )
            # 4. Top-p only - partial batch (half have p=0.9, half have p=1.0)
            p_partial = torch.full(
                (batch_size,), 0.9, dtype=torch.float32, device=device
            )
            p_partial[batch_size // 2 :] = 1.0  # No filtering for second half
            configs.append(
                BenchmarkConfig(
                    name=f"topp_partial_b{batch_size}_v{vocab_size // 1000}k",
                    batch_size=batch_size,
                    vocab_size=vocab_size,
                    k_values=None,
                    p_values=p_partial,
                    description=f"Top-p only (partial batch, 50% p=0.9, 50% p=1.0), "
                    f"batch={batch_size}, vocab={vocab_size}",
                    ops_pct=calculate_ops_pct(None, p_partial, vocab_size, batch_size),
                )
            )
            # 5. Mix of top-k and top-p (both applied to whole batch)
            k_mix = torch.full((batch_size,), 100, dtype=torch.int32, device=device)
            p_mix = torch.full((batch_size,), 0.9, dtype=torch.float32, device=device)
            configs.append(
                BenchmarkConfig(
                    name=f"topk_topp_whole_b{batch_size}_v{vocab_size // 1000}k",
                    batch_size=batch_size,
                    vocab_size=vocab_size,
                    k_values=k_mix,
                    p_values=p_mix,
                    description=f"Top-k + Top-p (whole batch, k=100, p=0.9), "
                    f"batch={batch_size}, vocab={vocab_size}",
                    ops_pct=calculate_ops_pct(k_mix, p_mix, vocab_size, batch_size),
                )
            )
            # 6. Mix with partial application (some rows k only, some p only, some both)
            k_mixed = torch.full(
                (batch_size,), vocab_size, dtype=torch.int32, device=device
            )
            p_mixed = torch.full((batch_size,), 1.0, dtype=torch.float32, device=device)
            # First third: k only
            third = batch_size // 3
            k_mixed[:third] = 50
            # Second third: p only
            p_mixed[third : 2 * third] = 0.5
            # Last third: both k and p
            k_mixed[2 * third :] = 100
            p_mixed[2 * third :] = 0.9
            configs.append(
                BenchmarkConfig(
                    name=f"mixed_partial_b{batch_size}_v{vocab_size // 1000}k",
                    batch_size=batch_size,
                    vocab_size=vocab_size,
                    k_values=k_mixed,
                    p_values=p_mixed,
                    # Fixed: description previously said "1/3 p=0.9" but the
                    # p-only third is set to 0.5 above.
                    description=f"Mixed partial (1/3 k=50, 1/3 p=0.5, 1/3 both), "
                    f"batch={batch_size}, vocab={vocab_size}",
                    ops_pct=calculate_ops_pct(k_mixed, p_mixed, vocab_size, batch_size),
                )
            )
    return configs
def format_memory(bytes_val: int) -> str:
    """Format memory in human-readable form (largest fitting binary unit)."""
    for unit, threshold in (("GB", 1024**3), ("MB", 1024**2), ("KB", 1024)):
        if bytes_val >= threshold:
            return f"{bytes_val / threshold:.2f} {unit}"
    return f"{bytes_val} B"
def run_benchmark(
    configs: list[BenchmarkConfig],
    warmup_iters: int = 5,
    benchmark_iters: int = 20,
    verbose: bool = True,
):
    """Run every configuration against both implementations.

    Returns a list of per-config result dicts with timings, peak memory,
    and derived speedup / memory-ratio metrics (PyTorch relative to Triton).
    """
    results = []
    print("=" * 100)
    print("Top-k/Top-p Benchmark: Triton vs PyTorch Sort-based")
    print("=" * 100)
    print()
    for cfg in configs:
        if verbose:
            print(f"Running: {cfg.description}")
        # Fresh logits for each configuration.
        logits = create_logits(cfg.batch_size, cfg.vocab_size)
        # Triton implementation first.
        reset_memory_stats()
        triton_time, triton_mem = benchmark_function(
            apply_top_k_top_p_triton,
            logits,
            cfg.k_values,
            cfg.p_values,
            warmup_iters,
            benchmark_iters,
        )
        # Then the PyTorch sort-based implementation.
        reset_memory_stats()
        pytorch_time, pytorch_mem = benchmark_function(
            apply_top_k_top_p_pytorch,
            logits,
            cfg.k_values,
            cfg.p_values,
            warmup_iters,
            benchmark_iters,
        )
        # Ratios > 1 mean Triton is faster / uses less memory.
        speedup = pytorch_time / triton_time if triton_time > 0 else float("inf")
        mem_ratio = pytorch_mem / triton_mem if triton_mem > 0 else float("inf")
        results.append(
            {
                "config": cfg,
                "triton_time_ms": triton_time,
                "pytorch_time_ms": pytorch_time,
                "triton_mem": triton_mem,
                "pytorch_mem": pytorch_mem,
                "speedup": speedup,
                "mem_ratio": mem_ratio,
            }
        )
        if verbose:
            print(f" Triton: {triton_time:.3f} ms, {format_memory(triton_mem)}")
            print(f" PyTorch: {pytorch_time:.3f} ms, {format_memory(pytorch_mem)}")
            print(f" Speedup: {speedup:.2f}x, Memory ratio: {mem_ratio:.2f}x")
            print()
        # Free the logits before the next configuration.
        del logits
        reset_memory_stats()
    return results
def print_summary_table(results: list[dict]):
    """Print an aligned summary table of all results, grouped by vocab size."""
    print()
    print("=" * 130)
    print("SUMMARY TABLE")
    print("=" * 130)
    print()
    # Column headers.
    header = (
        f"{'Scenario':<40} {'Batch':>6} {'Vocab':>7} {'Ops%':>6} "
        f"{'Triton (ms)':>12} {'PyTorch (ms)':>13} {'Speedup':>8} "
        f"{'Tri Mem':>10} {'Pyt Mem':>10}"
    )
    print(header)
    print("-" * 130)
    prev_vocab = None
    for entry in results:
        cfg = entry["config"]
        # Horizontal rule whenever the vocab size changes.
        if prev_vocab != cfg.vocab_size:
            if prev_vocab is not None:
                print("-" * 130)
            prev_vocab = cfg.vocab_size
        # Scenario label is everything before the "_b<batch>" suffix.
        scenario = cfg.name.split("_b")[0]
        print(
            f"{scenario:<40} {cfg.batch_size:>6} {cfg.vocab_size:>7} "
            f"{cfg.ops_pct:>5.0f}% "
            f"{entry['triton_time_ms']:>12.3f} {entry['pytorch_time_ms']:>13.3f} "
            f"{entry['speedup']:>7.2f}x "
            f"{format_memory(entry['triton_mem']):>10} "
            f"{format_memory(entry['pytorch_mem']):>10}"
        )
    print("=" * 130)
def main():
    """Parse CLI arguments, build configurations, and run the benchmark."""
    parser = argparse.ArgumentParser(
        description="Benchmark Triton vs PyTorch sort-based top-k/top-p implementations"
    )
    # Use %(default)s so the help text can never drift from the actual
    # defaults (the previous text claimed "1 4 16 64" while the default
    # was [1, 4, 16, 64, 128, 512, 1024, 2048]).
    parser.add_argument(
        "--batch-sizes",
        type=int,
        nargs="+",
        default=[1, 4, 16, 64, 128, 512, 1024, 2048],
        help="Batch sizes to test (default: %(default)s)",
    )
    parser.add_argument(
        "--vocab-sizes",
        type=int,
        nargs="+",
        default=[32768, 131072],  # 32k, 128k
        help="Vocabulary sizes to test (default: %(default)s)",
    )
    parser.add_argument(
        "--warmup-iters",
        type=int,
        default=5,
        help="Number of warmup iterations (default: %(default)s)",
    )
    parser.add_argument(
        "--benchmark-iters",
        type=int,
        default=20,
        help="Number of benchmark iterations (default: %(default)s)",
    )
    parser.add_argument(
        "--quiet",
        action="store_true",
        help="Only print summary table",
    )
    args = parser.parse_args()
    # Print configuration
    print(f"Batch sizes: {args.batch_sizes}")
    print(f"Vocab sizes: {args.vocab_sizes}")
    print(f"Warmup iterations: {args.warmup_iters}")
    print(f"Benchmark iterations: {args.benchmark_iters}")
    print()
    # Bail out early if there is no GPU: both implementations need CUDA.
    if not torch.cuda.is_available():
        print("ERROR: CUDA is not available. This benchmark requires a GPU.")
        return
    device_name = torch.cuda.get_device_name(0)
    print(f"GPU: {device_name}")
    print()
    # Create configs
    configs = create_benchmark_configs(
        args.batch_sizes,
        args.vocab_sizes,
    )
    # Run benchmarks
    results = run_benchmark(
        configs,
        warmup_iters=args.warmup_iters,
        benchmark_iters=args.benchmark_iters,
        verbose=not args.quiet,
    )
    # Print summary
    print_summary_table(results)

View File

@@ -145,6 +145,7 @@ async def test_request_cancellation(server: RemoteOpenAIServer):
model=MODEL_NAME,
max_tokens=10000,
extra_body={"min_tokens": 10000},
temperature=0.0,
)
)
tasks.append(task)
@@ -163,7 +164,7 @@ async def test_request_cancellation(server: RemoteOpenAIServer):
# be able to respond to this one within the timeout
client = server.get_async_client(timeout=5)
response = await client.chat.completions.create(
messages=chat_input, model=MODEL_NAME, max_tokens=10
messages=chat_input, model=MODEL_NAME, max_tokens=10, temperature=0.0
)
assert len(response.choices) == 1

View File

@@ -5,8 +5,9 @@ import torch
from torch import Generator
from vllm.platforms import current_platform
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_pytorch
CUDA_DEVICE = "cuda" if current_platform.is_cuda() else None
DEVICE = current_platform.device_type
BATCH_SIZE = 1024
@@ -39,11 +40,11 @@ def test_topk_impl_equivalence():
)
# Top-k only implementation
result1 = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)
result1 = apply_top_k_top_p_pytorch(logits=logits.clone(), k=k, p=None)
# Top-p + top-k
no_op_top_p = torch.tensor([1.0])
result2 = apply_top_k_top_p(logits=logits.clone(), k=k, p=no_op_top_p)
result2 = apply_top_k_top_p_pytorch(logits=logits.clone(), k=k, p=no_op_top_p)
assert torch.allclose(result1, result2)
@@ -98,7 +99,7 @@ def test_flashinfer_sampler():
torch.randint(0, 2, (BATCH_SIZE,), generator=generator, dtype=torch.bool), 1.0
)
python_logits = apply_top_k_top_p(
python_logits = apply_top_k_top_p_pytorch(
logits=logits.clone(),
k=k_values,
p=p_values,
@@ -120,3 +121,451 @@ def test_flashinfer_sampler():
assert torch.allclose(python_probs, flashinfer_probs, atol=2e-2), (
"FlashInfer and Python sampling implementations do not match!"
)
# =============================================================================
# Triton kernel tests
# =============================================================================
@pytest.mark.skipif(CUDA_DEVICE is None, reason="CUDA not available")
class TestTritonTopkTopp:
"""Tests for the Triton top-k/top-p kernel."""
@pytest.fixture(autouse=True)
def setup(self):
"""Set up test fixtures."""
torch.set_default_device(CUDA_DEVICE)
self.generator = Generator(device=CUDA_DEVICE).manual_seed(42)
def _compare_results(
self,
logits: torch.Tensor,
k: torch.Tensor | None,
p: torch.Tensor | None,
):
"""Compare Triton kernel results with PyTorch sorting implementation.
For top-k only, we expect exact match.
For top-p (with or without top-k), we allow small differences due to
floating-point precision in probability sum calculations.
"""
from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton
# Clone logits for both implementations
logits_pytorch = logits.clone()
logits_triton = logits.clone().to(torch.float32)
# Apply PyTorch sorting implementation
result_pytorch = apply_top_k_top_p_pytorch(logits_pytorch, k, p)
# Apply Triton kernel
k_i32 = k.to(torch.int32) if k is not None else None
p_f32 = p.to(torch.float32) if p is not None else None
result_triton = apply_top_k_top_p_triton(logits_triton, k_i32, p_f32)
# Compare kept counts per row
pytorch_kept = (result_pytorch != float("-inf")).sum(dim=-1)
triton_kept = (result_triton != float("-inf")).sum(dim=-1)
if p is None:
# Top-k only: expect exact match
assert torch.equal(pytorch_kept, triton_kept), (
f"Top-k mask mismatch: PyTorch kept {pytorch_kept.tolist()}, "
f"Triton kept {triton_kept.tolist()}"
)
else:
# Top-p involved: allow small differences
# Either < 1% of kept values OR < 5 values absolute
max_diff = (pytorch_kept - triton_kept).abs().max().item()
max_kept = pytorch_kept.max().item()
if max_kept > 0 and max_diff > 3:
diff_pct = max_diff / max_kept * 100
assert diff_pct < 0.5, (
f"Top-p mask difference too large: {diff_pct:.2f}% "
f"(max diff {max_diff} values out of {max_kept})"
)
@pytest.mark.parametrize("batch_size", [1, 8, 32, 128, 512, 1024])
@pytest.mark.parametrize("vocab_size", [1024, 32000, 128256])
def test_topk_only(self, batch_size: int, vocab_size: int):
"""Test top-k only (p=None)."""
logits = torch.randn(
batch_size, vocab_size, generator=self.generator, dtype=torch.float32
)
k = torch.randint(
1, min(100, vocab_size), (batch_size,), generator=self.generator
)
# Randomly disable top-k for some rows (~25%)
disable_mask = torch.randint(0, 4, (batch_size,), generator=self.generator) == 0
k.masked_fill_(disable_mask, vocab_size)
self._compare_results(logits, k, p=None)
@pytest.mark.parametrize("batch_size", [1, 8, 32, 128, 512, 1024])
@pytest.mark.parametrize("vocab_size", [1024, 32000, 128256])
def test_topp_only(self, batch_size: int, vocab_size: int):
"""Test top-p only (k=None)."""
logits = torch.randn(
batch_size, vocab_size, generator=self.generator, dtype=torch.float32
)
p = torch.rand(batch_size, generator=self.generator) * 0.9 + 0.1 # [0.1, 1.0]
# Randomly disable top-p for some rows (~25%)
disable_mask = torch.randint(0, 4, (batch_size,), generator=self.generator) == 0
p.masked_fill_(disable_mask, 1.0)
self._compare_results(logits, k=None, p=p)
@pytest.mark.parametrize("batch_size", [1, 8, 32, 128, 512, 1024])
@pytest.mark.parametrize("vocab_size", [1024, 32000, 128256])
def test_topk_and_topp(self, batch_size: int, vocab_size: int):
"""Test combined top-k and top-p."""
logits = torch.randn(
batch_size, vocab_size, generator=self.generator, dtype=torch.float32
)
k = torch.randint(
1, min(100, vocab_size), (batch_size,), generator=self.generator
)
p = torch.rand(batch_size, generator=self.generator) * 0.9 + 0.1 # [0.1, 1.0]
# Randomly disable top-k for some rows (~25%)
disable_k = torch.randint(0, 4, (batch_size,), generator=self.generator) == 0
k.masked_fill_(disable_k, vocab_size)
# Randomly disable top-p for some rows (~25%)
disable_p = torch.randint(0, 4, (batch_size,), generator=self.generator) == 0
p.masked_fill_(disable_p, 1.0)
self._compare_results(logits, k, p)
def test_both_disabled(self):
"""Test when both k and p are None (should be no-op)."""
from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton
logits = torch.randn(32, 1024, generator=self.generator, dtype=torch.float32)
logits_clone = logits.clone()
result = apply_top_k_top_p_triton(logits_clone, k=None, p=None)
assert torch.equal(result, logits), "Should be no-op when both k and p are None"
def test_extreme_k_values(self):
"""Test edge cases for k values."""
batch_size, vocab_size = 16, 1024
logits = torch.randn(
batch_size, vocab_size, generator=self.generator, dtype=torch.float32
)
# k=1 (keep only top 1)
k = torch.ones(batch_size, dtype=torch.int32)
self._compare_results(logits.clone(), k, p=None)
# k=vocab_size (keep all)
k = torch.full((batch_size,), vocab_size, dtype=torch.int32)
self._compare_results(logits.clone(), k, p=None)
# Mixed extreme values
k = torch.tensor([1, vocab_size, 2, vocab_size - 1] * 4, dtype=torch.int32)
self._compare_results(logits.clone(), k, p=None)
def test_extreme_p_values(self):
"""Test edge cases for p values."""
batch_size, vocab_size = 16, 1024
logits = torch.randn(
batch_size, vocab_size, generator=self.generator, dtype=torch.float32
)
# p close to 0 (very restrictive)
p = torch.full((batch_size,), 0.01, dtype=torch.float32)
self._compare_results(logits.clone(), k=None, p=p)
# p=1.0 (keep all)
p = torch.ones(batch_size, dtype=torch.float32)
self._compare_results(logits.clone(), k=None, p=p)
# Mixed values
p = torch.tensor([0.1, 0.5, 0.9, 1.0] * 4, dtype=torch.float32)
self._compare_results(logits.clone(), k=None, p=p)
def test_large_batch(self):
"""Test with a large batch size."""
batch_size, vocab_size = 512, 32000
logits = torch.randn(
batch_size, vocab_size, generator=self.generator, dtype=torch.float32
)
k = torch.randint(1, 50, (batch_size,), generator=self.generator)
p = torch.rand(batch_size, generator=self.generator) * 0.5 + 0.5
self._compare_results(logits, k, p)
# -----------------------------------------------------------------
# Tests for -inf logits (e.g. from grammar / structured output masks)
# -----------------------------------------------------------------
@pytest.mark.parametrize("inf_fraction", [0.5, 0.9, 0.99])
def test_topk_with_neginf_logits(self, inf_fraction: float):
"""Top-k with many -inf logits (simulating grammar bitmask).
The kernel must not produce NaN when most logits are -inf, which
can happen when structured-output grammar masks are applied before
sampling.
"""
from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton
batch_size, vocab_size = 32, 128256
logits = torch.randn(
batch_size, vocab_size, generator=self.generator, dtype=torch.float32
)
# Mask a fraction of logits to -inf.
mask = (
torch.rand(batch_size, vocab_size, generator=self.generator) < inf_fraction
)
logits[mask] = float("-inf")
k = torch.randint(
1, 50, (batch_size,), generator=self.generator, dtype=torch.int32
)
result = apply_top_k_top_p_triton(logits.clone(), k, None)
assert not result.isnan().any(), "NaN found in top-k result with -inf logits"
for i in range(batch_size):
kept = (result[i] > float("-inf")).sum().item()
assert kept <= k[i].item(), f"Row {i}: kept {kept} > k={k[i].item()}"
# At least one value should survive unless the row was all -inf.
finite_in = (logits[i] > float("-inf")).sum().item()
if finite_in > 0:
assert kept > 0, f"Row {i}: no tokens kept despite finite input"
@pytest.mark.parametrize("inf_fraction", [0.5, 0.9, 0.99])
def test_topp_with_neginf_logits(self, inf_fraction: float):
"""Top-p with many -inf logits."""
from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton
batch_size, vocab_size = 32, 128256
logits = torch.randn(
batch_size, vocab_size, generator=self.generator, dtype=torch.float32
)
mask = (
torch.rand(batch_size, vocab_size, generator=self.generator) < inf_fraction
)
logits[mask] = float("-inf")
p = (
torch.rand(batch_size, generator=self.generator, dtype=torch.float32) * 0.9
+ 0.1
)
result = apply_top_k_top_p_triton(logits.clone(), None, p)
assert not result.isnan().any(), "NaN found in top-p result with -inf logits"
for i in range(batch_size):
finite_in = (logits[i] > float("-inf")).sum().item()
kept = (result[i] > float("-inf")).sum().item()
if finite_in > 0:
assert kept > 0, f"Row {i}: no tokens kept despite finite input"
@pytest.mark.parametrize("inf_fraction", [0.5, 0.9, 0.99])
def test_topk_topp_with_neginf_logits(self, inf_fraction: float):
"""Combined top-k + top-p with many -inf logits."""
from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton
batch_size, vocab_size = 32, 128256
logits = torch.randn(
batch_size, vocab_size, generator=self.generator, dtype=torch.float32
)
mask = (
torch.rand(batch_size, vocab_size, generator=self.generator) < inf_fraction
)
logits[mask] = float("-inf")
k = torch.randint(
1, 50, (batch_size,), generator=self.generator, dtype=torch.int32
)
p = (
torch.rand(batch_size, generator=self.generator, dtype=torch.float32) * 0.9
+ 0.1
)
result = apply_top_k_top_p_triton(logits.clone(), k, p)
assert not result.isnan().any(), (
"NaN found in top-k+top-p result with -inf logits"
)
for i in range(batch_size):
kept = (result[i] > float("-inf")).sum().item()
assert kept <= k[i].item(), f"Row {i}: kept {kept} > k={k[i].item()}"
def test_all_neginf_logits(self):
"""All logits are -inf (fully masked). Kernel should be a no-op."""
from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton
batch_size, vocab_size = 16, 128256
logits = torch.full(
(batch_size, vocab_size), float("-inf"), dtype=torch.float32
)
k = torch.randint(
1, 50, (batch_size,), generator=self.generator, dtype=torch.int32
)
p = torch.full((batch_size,), 0.9, dtype=torch.float32)
# top-k only
result = apply_top_k_top_p_triton(logits.clone(), k, None)
assert not result.isnan().any(), "NaN from all-inf top-k"
assert (result == float("-inf")).all(), "Expected all -inf unchanged"
# top-p only
result = apply_top_k_top_p_triton(logits.clone(), None, p)
assert not result.isnan().any(), "NaN from all-inf top-p"
assert (result == float("-inf")).all(), "Expected all -inf unchanged"
# top-k + top-p
result = apply_top_k_top_p_triton(logits.clone(), k, p)
assert not result.isnan().any(), "NaN from all-inf top-k+top-p"
assert (result == float("-inf")).all(), "Expected all -inf unchanged"
def test_few_valid_tokens_with_neginf(self):
"""Only a handful of tokens are finite per row (strict grammar)."""
from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton
batch_size, vocab_size = 32, 128256
logits = torch.full(
(batch_size, vocab_size), float("-inf"), dtype=torch.float32
)
# Allow only 5 random tokens per row to be finite.
for i in range(batch_size):
indices = torch.randperm(vocab_size, generator=self.generator)[:5]
logits[i, indices] = torch.randn(
5, generator=self.generator, dtype=torch.float32
)
k = torch.full((batch_size,), 50, dtype=torch.int32)
p = torch.full((batch_size,), 0.9, dtype=torch.float32)
# top-k only (k=50 but only 5 finite → keep all 5)
result = apply_top_k_top_p_triton(logits.clone(), k, None)
assert not result.isnan().any()
for i in range(batch_size):
kept = (result[i] > float("-inf")).sum().item()
assert kept == 5, f"Row {i}: expected 5 kept, got {kept}"
# top-k with k < num_finite
k_small = torch.full((batch_size,), 3, dtype=torch.int32)
result = apply_top_k_top_p_triton(logits.clone(), k_small, None)
assert not result.isnan().any()
for i in range(batch_size):
kept = (result[i] > float("-inf")).sum().item()
assert kept <= 3, f"Row {i}: expected <=3 kept, got {kept}"
# top-p only
result = apply_top_k_top_p_triton(logits.clone(), None, p)
assert not result.isnan().any()
for i in range(batch_size):
kept = (result[i] > float("-inf")).sum().item()
assert kept > 0, f"Row {i}: no tokens kept"
@pytest.mark.parametrize("num_valid", [1, 2, 5, 10, 50])
@pytest.mark.parametrize(
"mode",
["topk_only", "topp_only", "topk_and_topp"],
)
def test_equal_logits_few_valid(self, num_valid: int, mode: str):
"""Few valid tokens all sharing the same logit value.
This is the pattern produced by grammar bitmask filtering when
the model assigns similar scores to the few allowed tokens.
The ternary search can converge to a pivot equal to max_logit,
causing the strict `>` keep_mask to exclude everything.
Regression test for the `final_pivot >= max_logit` guard.
"""
from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton
batch_size, vocab_size = 32, 128256
logits = torch.full(
(batch_size, vocab_size), float("-inf"), dtype=torch.float32
)
# Set exactly `num_valid` tokens per row to the SAME finite value.
for i in range(batch_size):
indices = torch.randperm(vocab_size, generator=self.generator)[:num_valid]
logits[i, indices] = 1.0 # all equal
k: torch.Tensor | None = None
p: torch.Tensor | None = None
if mode in ("topk_only", "topk_and_topp"):
k = torch.full((batch_size,), max(1, num_valid - 1), dtype=torch.int32)
if mode in ("topp_only", "topk_and_topp"):
p = torch.full((batch_size,), 0.95, dtype=torch.float32)
result = apply_top_k_top_p_triton(logits.clone(), k, p)
assert not result.isnan().any(), "NaN in equal-logit result"
for i in range(batch_size):
kept = (result[i] > float("-inf")).sum().item()
# The key invariant: at least one token must survive.
# With all-equal logits the pivot search can't differentiate
# tokens, so the guard may keep more than k — that is the
# intended safe fallback.
assert kept > 0, (
f"Row {i}: all tokens masked with {num_valid} equal-valued "
f"finite logits ({mode})"
)
@pytest.mark.parametrize("num_valid", [2, 5, 10])
def test_nearly_equal_logits_topp(self, num_valid: int):
"""Few valid tokens with very similar (but not identical) logits.
Ensures the kernel handles near-degenerate probability
distributions where the ternary search range collapses.
"""
from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton
batch_size, vocab_size = 32, 128256
logits = torch.full(
(batch_size, vocab_size), float("-inf"), dtype=torch.float32
)
for i in range(batch_size):
indices = torch.randperm(vocab_size, generator=self.generator)[:num_valid]
# Tiny spread: values in [1.0, 1.0 + 1e-6]
logits[i, indices] = (
1.0
+ torch.rand(num_valid, generator=self.generator, dtype=torch.float32)
* 1e-6
)
p = torch.full((batch_size,), 0.95, dtype=torch.float32)
result = apply_top_k_top_p_triton(logits.clone(), None, p)
assert not result.isnan().any(), "NaN in nearly-equal-logit result"
for i in range(batch_size):
kept = (result[i] > float("-inf")).sum().item()
assert kept > 0, (
f"Row {i}: all tokens masked with {num_valid} "
f"nearly-equal finite logits"
)
def test_mixed_neginf_and_normal_rows(self):
"""Batch with a mix of normal rows and heavily-masked rows."""
from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton
batch_size, vocab_size = 32, 32000
logits = torch.randn(
batch_size, vocab_size, generator=self.generator, dtype=torch.float32
)
# Mask even rows heavily (99% -inf), leave odd rows normal.
for i in range(0, batch_size, 2):
mask = torch.rand(vocab_size, generator=self.generator) < 0.99
logits[i][mask] = float("-inf")
k = torch.randint(
1, 50, (batch_size,), generator=self.generator, dtype=torch.int32
)
p = (
torch.rand(batch_size, generator=self.generator, dtype=torch.float32) * 0.9
+ 0.1
)
result = apply_top_k_top_p_triton(logits.clone(), k, p)
assert not result.isnan().any(), "NaN in mixed normal/-inf batch"
for i in range(batch_size):
kept = (result[i] > float("-inf")).sum().item()
assert kept <= k[i].item()
finite_in = (logits[i] > float("-inf")).sum().item()
if finite_in > 0:
assert kept > 0, f"Row {i}: no tokens kept"

View File

@@ -14,16 +14,12 @@ def cdiv(a: int, b: int) -> int:
def next_power_of_2(n: int) -> int:
"""The next power of 2 (inclusive)"""
if n < 1:
return 1
return 1 << (n - 1).bit_length()
return 1 if n < 1 else 1 << (n - 1).bit_length()
def prev_power_of_2(n: int) -> int:
"""The previous power of 2 (inclusive)"""
if n <= 0:
return 0
return 1 << (n.bit_length() - 1)
return 0 if n <= 0 else 1 << (n.bit_length() - 1)
def round_up(x: int, y: int) -> int:

View File

@@ -11,6 +11,10 @@ from vllm._aiter_ops import rocm_aiter_ops
from vllm.config.model import LogprobsMode
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform
from vllm.triton_utils import HAS_TRITON
if HAS_TRITON:
from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton
logger = init_logger(__name__)
@@ -87,8 +91,6 @@ class TopKTopPSampler(nn.Module):
else:
self.forward = self.forward_native
self.apply_top_k_top_p = apply_top_k_top_p
def forward_native(
self,
logits: torch.Tensor,
@@ -101,7 +103,7 @@ class TopKTopPSampler(nn.Module):
The logits tensor may be updated in-place.
"""
logits = self.apply_top_k_top_p(logits, k, p)
logits = apply_top_k_top_p(logits, k, p)
logits_to_return = None
if self.logprobs_mode == "processed_logits":
logits_to_return = logits
@@ -149,7 +151,7 @@ class TopKTopPSampler(nn.Module):
The logits tensor may be updated in-place.
"""
logits = self.apply_top_k_top_p(logits, k, p)
logits = apply_top_k_top_p_pytorch(logits, k, p, allow_cpu_sync=True)
logits_to_return = None
if self.logprobs_mode == "processed_logits":
logits_to_return = logits
@@ -158,14 +160,14 @@ class TopKTopPSampler(nn.Module):
if len(generators) != logits.shape[0]:
return compiled_random_sample(logits), logits_to_return
else:
probs = logits.softmax(dim=-1, dtype=torch.float32)
q = torch.empty_like(probs)
q.exponential_()
for i, generator in generators.items():
q[i].exponential_(generator=generator)
return probs.div_(q).argmax(dim=-1).view(-1), logits_to_return
probs = logits.softmax(dim=-1, dtype=torch.float32)
q = torch.empty_like(probs)
q.exponential_()
for i, generator in generators.items():
q[i].exponential_(generator=generator)
return probs.div_(q).argmax(dim=-1).view(-1), logits_to_return
def forward_hip(
self,
@@ -241,9 +243,23 @@ def compiled_random_sample(logits: torch.Tensor) -> torch.Tensor:
def apply_top_k_top_p(
logits: torch.Tensor, k: torch.Tensor | None, p: torch.Tensor | None
) -> torch.Tensor:
if p is None and k is None:
return logits
if HAS_TRITON and logits.shape[0] >= 8:
return apply_top_k_top_p_triton(logits, k, p)
# Use pytorch sort implementation for small batch sizes.
return apply_top_k_top_p_pytorch(logits, k, p)
def apply_top_k_top_p_pytorch(
logits: torch.Tensor,
k: torch.Tensor | None,
p: torch.Tensor | None,
allow_cpu_sync: bool = False,
) -> torch.Tensor:
"""Apply top-k and top-p masks to the logits.
@@ -256,8 +272,9 @@ def apply_top_k_top_p(
if k is None:
return logits
# Avoid sorting vocab for top-k only case.
return apply_top_k_only(logits, k)
if allow_cpu_sync:
# Avoid sorting vocab for top-k only case.
return apply_top_k_only(logits, k)
logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
@@ -279,18 +296,16 @@ def apply_top_k_top_p(
logits_sort.masked_fill_(top_p_mask, -float("inf"))
# Re-sort the probabilities.
logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
return logits
return logits.scatter_(dim=-1, index=logits_idx, src=logits_sort)
def apply_top_k_only(
logits: torch.Tensor,
k: torch.Tensor,
) -> torch.Tensor:
def apply_top_k_only(logits: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
"""
Apply top-k mask to the logits.
This implementation doesn't involve sorting the entire vocab.
Note however that it involves a GPU->CPU sync which can be detrimental for
async scheduling performance.
The logits tensor may be updated in-place.
"""
@@ -304,8 +319,7 @@ def apply_top_k_only(
top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long())
# Handle non-topk rows.
top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf"))
logits.masked_fill_(logits < top_k_mask, -float("inf"))
return logits
return logits.masked_fill_(logits < top_k_mask, -float("inf"))
def random_sample(

File diff suppressed because it is too large Load Diff