[Kernel] Triton-based Top-k and Top-p sampler kernels (#33538)
Signed-off-by: js_park <cakeng@naver.com> Signed-off-by: Jongseok Park <37990712+cakeng@users.noreply.github.com> Signed-off-by: Sunga Kim <sunga.kim@berkeley.edu> Signed-off-by: Nick Hill <nickhill123@gmail.com> Co-authored-by: Sunga Kim <sunga.kim@berkeley.edu> Co-authored-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
471
benchmarks/benchmark_topk_topp.py
Normal file
471
benchmarks/benchmark_topk_topp.py
Normal file
@@ -0,0 +1,471 @@
|
||||
#!/usr/bin/env python3
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Benchmark comparing Triton vs PyTorch sort-based top-k/top-p implementations.
|
||||
|
||||
Compares:
|
||||
- apply_top_k_top_p_triton (Triton binary search)
|
||||
- apply_top_k_top_p (PyTorch sort-based)
|
||||
|
||||
Scenarios:
|
||||
- top_k only (whole batch, partial batch)
|
||||
- top_p only (whole batch, partial batch)
|
||||
- mix of top_k and top_p
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_pytorch
|
||||
from vllm.v1.sample.ops.topk_topp_triton import (
|
||||
apply_top_k_top_p_triton,
|
||||
reset_buffer_cache,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class BenchmarkConfig:
    """Configuration for a benchmark run."""

    # Unique identifier, e.g. "topk_whole_b64_v32k" (used by the summary table,
    # which splits on "_b" to recover the scenario name).
    name: str
    batch_size: int
    vocab_size: int
    # k and p can be tensors or None
    k_values: torch.Tensor | None  # [batch_size] or None
    p_values: torch.Tensor | None  # [batch_size] or None
    # Free-form text printed while the benchmark for this config runs.
    description: str
    ops_pct: float = 0.0  # Percentage of ops relative to batch size
|
||||
|
||||
|
||||
def calculate_ops_pct(
|
||||
k_values: torch.Tensor | None,
|
||||
p_values: torch.Tensor | None,
|
||||
vocab_size: int,
|
||||
batch_size: int,
|
||||
) -> float:
|
||||
"""
|
||||
Calculate the percentage of active top-k and top-p operations.
|
||||
|
||||
Returns percentage where 100% = batch_size ops.
|
||||
E.g., if all rows have both top-k and top-p active, returns 200%.
|
||||
"""
|
||||
active_ops = 0
|
||||
|
||||
if k_values is not None:
|
||||
# Count rows where k < vocab_size (active top-k filtering)
|
||||
active_ops += (k_values < vocab_size).sum().item()
|
||||
|
||||
if p_values is not None:
|
||||
# Count rows where p < 1.0 (active top-p filtering)
|
||||
active_ops += (p_values < 1.0).sum().item()
|
||||
|
||||
return (active_ops / batch_size) * 100 if batch_size > 0 else 0.0
|
||||
|
||||
|
||||
def create_logits(
    batch_size: int, vocab_size: int, device: str = "cuda"
) -> torch.Tensor:
    """Create random logits mimicking a realistic LLM distribution.

    Uses a Zipf-like probability distribution (rank^-1.1) converted to logits
    via log, then randomly permuted per row. This produces a peaked distribution
    where a small number of tokens capture most probability mass, similar to
    real model outputs.

    Returns a float32 tensor of shape [batch_size, vocab_size] on ``device``.
    """
    # Create Zipf-like probabilities: p(rank) ~ rank^(-alpha)
    ranks = torch.arange(1, vocab_size + 1, dtype=torch.float32, device=device)
    probs = ranks.pow(-1.1)
    probs = probs / probs.sum()

    # Convert to logits (log-probabilities, unnormalized is fine)
    base_logits = probs.log()

    # Vectorized per-row permutation: argsort of i.i.d. uniform noise yields
    # an independent random permutation for every row in a single kernel
    # launch, replacing the original O(batch_size) Python loop of
    # torch.randperm calls (one launch per row).
    perm = torch.rand(batch_size, vocab_size, device=device).argsort(dim=-1)
    return base_logits.expand(batch_size, -1).gather(1, perm)
|
||||
|
||||
|
||||
def measure_memory() -> tuple[int, int]:
    """Return (current_allocated, peak_allocated) CUDA memory in bytes.

    Synchronizes first so pending kernels are reflected in the counters.
    NOTE: the second element is ``max_memory_allocated`` (peak allocation
    since the last ``reset_peak_memory_stats``), not reserved memory.
    """
    torch.cuda.synchronize()
    return torch.cuda.memory_allocated(), torch.cuda.max_memory_allocated()
|
||||
|
||||
|
||||
def reset_memory_stats():
    """Reset CUDA peak-memory statistics and release cached GPU memory.

    Also clears the Triton kernel's persistent buffer cache so each
    benchmark run starts from a clean memory baseline.
    """
    reset_buffer_cache()
    # Collect unreachable Python objects first so any CUDA memory they
    # held is returned to the caching allocator and can actually be
    # released by empty_cache() below. (The original order ran
    # empty_cache() before gc.collect(), leaving just-collected blocks
    # in the allocator cache.)
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
|
||||
def benchmark_function(
    func,
    logits: torch.Tensor,
    k: torch.Tensor | None,
    p: torch.Tensor | None,
    warmup_iters: int = 5,
    benchmark_iters: int = 20,
) -> tuple[float, int]:
    """
    Benchmark a function and return (avg_time_ms, peak_memory_bytes).

    ``func`` is invoked as ``func(logits_copy, k, p)`` with a fresh clone
    of ``logits`` each iteration — presumably because implementations may
    modify logits in place (TODO confirm against the sampler APIs).
    Timing uses CUDA events, so only GPU time between the event pair is
    measured; the clone is issued *before* the start event and is thus
    excluded from timing, but its allocation is included in peak memory.

    Returns average time in milliseconds and peak memory usage.
    """
    # Warmup (also absorbs one-time costs such as Triton JIT compilation)
    for _ in range(warmup_iters):
        logits_copy = logits.clone()
        func(logits_copy, k, p)
    torch.cuda.synchronize()

    # Reset memory stats before benchmark
    reset_memory_stats()

    # Benchmark: one dedicated start/end event pair per iteration
    start_events = [
        torch.cuda.Event(enable_timing=True) for _ in range(benchmark_iters)
    ]
    end_events = [torch.cuda.Event(enable_timing=True) for _ in range(benchmark_iters)]

    for i in range(benchmark_iters):
        logits_copy = logits.clone()
        start_events[i].record()
        func(logits_copy, k, p)
        end_events[i].record()

    # Single sync at the end; elapsed_time() requires both events completed.
    torch.cuda.synchronize()

    # Calculate timing (elapsed_time returns milliseconds)
    times = [
        start_events[i].elapsed_time(end_events[i]) for i in range(benchmark_iters)
    ]
    avg_time = sum(times) / len(times)

    # Get peak memory since the reset above
    _, peak_memory = measure_memory()

    return avg_time, peak_memory
|
||||
|
||||
|
||||
def create_benchmark_configs(
    batch_sizes: list[int],
    vocab_sizes: list[int],
    device: str = "cuda",
) -> list[BenchmarkConfig]:
    """Create all benchmark configurations.

    For every (vocab_size, batch_size) pair, six scenarios are generated:
      1. top-k only, whole batch        4. top-p only, partial batch
      2. top-k only, partial batch      5. top-k + top-p, whole batch
      3. top-p only, whole batch        6. mixed partial (k / p / both)
    """

    def _config(
        scenario: str,
        batch_size: int,
        vocab_size: int,
        k: torch.Tensor | None,
        p: torch.Tensor | None,
        detail: str,
    ) -> BenchmarkConfig:
        # Shared boilerplate: name/description formatting and ops_pct
        # computation, identical across all six scenarios.
        return BenchmarkConfig(
            name=f"{scenario}_b{batch_size}_v{vocab_size // 1000}k",
            batch_size=batch_size,
            vocab_size=vocab_size,
            k_values=k,
            p_values=p,
            description=f"{detail}, batch={batch_size}, vocab={vocab_size}",
            ops_pct=calculate_ops_pct(k, p, vocab_size, batch_size),
        )

    configs = []

    for vocab_size in vocab_sizes:
        for batch_size in batch_sizes:
            # 1. Top-k only - whole batch (all rows have k < vocab_size)
            k_all = torch.full((batch_size,), 50, dtype=torch.int32, device=device)
            configs.append(
                _config(
                    "topk_whole",
                    batch_size,
                    vocab_size,
                    k_all,
                    None,
                    "Top-k only (whole batch, k=50)",
                )
            )

            # 2. Top-k only - partial batch (half have k=50, half have k=vocab_size)
            k_partial = torch.full((batch_size,), 50, dtype=torch.int32, device=device)
            k_partial[batch_size // 2 :] = vocab_size  # No filtering for second half
            configs.append(
                _config(
                    "topk_partial",
                    batch_size,
                    vocab_size,
                    k_partial,
                    None,
                    "Top-k only (partial batch, 50% k=50, 50% k=vocab)",
                )
            )

            # 3. Top-p only - whole batch (all rows have p < 1.0)
            p_all = torch.full((batch_size,), 0.9, dtype=torch.float32, device=device)
            configs.append(
                _config(
                    "topp_whole",
                    batch_size,
                    vocab_size,
                    None,
                    p_all,
                    "Top-p only (whole batch, p=0.9)",
                )
            )

            # 4. Top-p only - partial batch (half have p=0.9, half have p=1.0)
            p_partial = torch.full(
                (batch_size,), 0.9, dtype=torch.float32, device=device
            )
            p_partial[batch_size // 2 :] = 1.0  # No filtering for second half
            configs.append(
                _config(
                    "topp_partial",
                    batch_size,
                    vocab_size,
                    None,
                    p_partial,
                    "Top-p only (partial batch, 50% p=0.9, 50% p=1.0)",
                )
            )

            # 5. Mix of top-k and top-p (both applied to whole batch)
            k_mix = torch.full((batch_size,), 100, dtype=torch.int32, device=device)
            p_mix = torch.full((batch_size,), 0.9, dtype=torch.float32, device=device)
            configs.append(
                _config(
                    "topk_topp_whole",
                    batch_size,
                    vocab_size,
                    k_mix,
                    p_mix,
                    "Top-k + Top-p (whole batch, k=100, p=0.9)",
                )
            )

            # 6. Mix with partial application (some rows k only, some p only,
            # some both). Note: for batch_size < 3 the thirds collapse and the
            # whole batch falls into the "both" bucket.
            k_mixed = torch.full(
                (batch_size,), vocab_size, dtype=torch.int32, device=device
            )
            p_mixed = torch.full((batch_size,), 1.0, dtype=torch.float32, device=device)
            third = batch_size // 3
            # First third: k only
            k_mixed[:third] = 50
            # Second third: p only
            p_mixed[third : 2 * third] = 0.5
            # Last third: both k and p
            k_mixed[2 * third :] = 100
            p_mixed[2 * third :] = 0.9
            configs.append(
                _config(
                    "mixed_partial",
                    batch_size,
                    vocab_size,
                    k_mixed,
                    p_mixed,
                    # Bug fix: the original description said "1/3 p=0.9", but
                    # the middle third actually uses p=0.5 (see above).
                    "Mixed partial (1/3 k=50, 1/3 p=0.5, 1/3 both)",
                )
            )

    return configs
|
||||
|
||||
|
||||
def format_memory(bytes_val: int) -> str:
    """Format a byte count as a human-readable string (B/KB/MB/GB)."""
    # Walk the units from largest to smallest; first threshold that fits wins.
    for threshold, unit in ((1024**3, "GB"), (1024**2, "MB"), (1024, "KB")):
        if bytes_val >= threshold:
            return f"{bytes_val / threshold:.2f} {unit}"
    return f"{bytes_val} B"
|
||||
|
||||
|
||||
def run_benchmark(
    configs: list[BenchmarkConfig],
    warmup_iters: int = 5,
    benchmark_iters: int = 20,
    verbose: bool = True,
):
    """Run all benchmarks and print results.

    For each config, benchmarks the Triton kernel and the PyTorch
    sort-based implementation on the same freshly created logits, and
    returns a list of result dicts with timings (ms), peak memory
    (bytes), and the PyTorch/Triton time and memory ratios.
    """
    results = []

    print("=" * 100)
    print("Top-k/Top-p Benchmark: Triton vs PyTorch Sort-based")
    print("=" * 100)
    print()

    for config in configs:
        if verbose:
            print(f"Running: {config.description}")

        # Create fresh logits for this config
        logits = create_logits(config.batch_size, config.vocab_size)

        # Benchmark Triton
        reset_memory_stats()
        triton_time, triton_mem = benchmark_function(
            apply_top_k_top_p_triton,
            logits,
            config.k_values,
            config.p_values,
            warmup_iters,
            benchmark_iters,
        )

        # Benchmark PyTorch
        reset_memory_stats()
        pytorch_time, pytorch_mem = benchmark_function(
            apply_top_k_top_p_pytorch,
            logits,
            config.k_values,
            config.p_values,
            warmup_iters,
            benchmark_iters,
        )

        # Ratios are PyTorch-over-Triton: values > 1 mean Triton is
        # faster / uses less memory. Zero denominators map to inf.
        speedup = pytorch_time / triton_time if triton_time > 0 else float("inf")
        mem_ratio = pytorch_mem / triton_mem if triton_mem > 0 else float("inf")

        result = {
            "config": config,
            "triton_time_ms": triton_time,
            "pytorch_time_ms": pytorch_time,
            "triton_mem": triton_mem,
            "pytorch_mem": pytorch_mem,
            "speedup": speedup,
            "mem_ratio": mem_ratio,
        }
        results.append(result)

        if verbose:
            print(f"  Triton: {triton_time:.3f} ms, {format_memory(triton_mem)}")
            print(f"  PyTorch: {pytorch_time:.3f} ms, {format_memory(pytorch_mem)}")
            print(f"  Speedup: {speedup:.2f}x, Memory ratio: {mem_ratio:.2f}x")
            print()

        # Clean up before the next config
        del logits
        reset_memory_stats()

    return results
|
||||
|
||||
|
||||
def print_summary_table(results: list[dict]):
    """Print a summary table of results.

    Expects the dicts produced by ``run_benchmark``; rows are printed in
    input order with a separator line whenever the vocab size changes.
    """
    print()
    print("=" * 130)
    print("SUMMARY TABLE")
    print("=" * 130)
    print()

    # Header
    header = (
        f"{'Scenario':<40} {'Batch':>6} {'Vocab':>7} {'Ops%':>6} "
        f"{'Triton (ms)':>12} {'PyTorch (ms)':>13} {'Speedup':>8} "
        f"{'Tri Mem':>10} {'Pyt Mem':>10}"
    )
    print(header)
    print("-" * 130)

    # Group by scenario type
    current_vocab = None
    for result in results:
        config = result["config"]

        # Add separator between vocab sizes
        if current_vocab != config.vocab_size:
            if current_vocab is not None:
                print("-" * 130)
            current_vocab = config.vocab_size

        # Config names follow "<scenario>_b<batch>_v<vocab>k", so the part
        # before "_b" is the scenario label.
        scenario = config.name.split("_b")[0]  # Extract scenario name
        print(
            f"{scenario:<40} {config.batch_size:>6} {config.vocab_size:>7} "
            f"{config.ops_pct:>5.0f}% "
            f"{result['triton_time_ms']:>12.3f} {result['pytorch_time_ms']:>13.3f} "
            f"{result['speedup']:>7.2f}x "
            f"{format_memory(result['triton_mem']):>10} "
            f"{format_memory(result['pytorch_mem']):>10}"
        )

    print("=" * 130)
|
||||
|
||||
|
||||
def main():
    """Parse CLI arguments, run the benchmark suite, and print a summary."""
    parser = argparse.ArgumentParser(
        description="Benchmark Triton vs PyTorch sort-based top-k/top-p implementations"
    )
    parser.add_argument(
        "--batch-sizes",
        type=int,
        nargs="+",
        default=[1, 4, 16, 64, 128, 512, 1024, 2048],
        # Bug fix: the help text previously claimed "default: 1 4 16 64",
        # which did not match the actual default list above.
        help="Batch sizes to test (default: 1 4 16 64 128 512 1024 2048)",
    )
    parser.add_argument(
        "--vocab-sizes",
        type=int,
        nargs="+",
        default=[32768, 131072],  # 32k, 128k
        help="Vocabulary sizes to test (default: 32768 131072)",
    )
    parser.add_argument(
        "--warmup-iters",
        type=int,
        default=5,
        help="Number of warmup iterations (default: 5)",
    )
    parser.add_argument(
        "--benchmark-iters",
        type=int,
        default=20,
        help="Number of benchmark iterations (default: 20)",
    )
    parser.add_argument(
        "--quiet",
        action="store_true",
        help="Only print summary table",
    )

    args = parser.parse_args()

    # Print configuration
    print(f"Batch sizes: {args.batch_sizes}")
    print(f"Vocab sizes: {args.vocab_sizes}")
    print(f"Warmup iterations: {args.warmup_iters}")
    print(f"Benchmark iterations: {args.benchmark_iters}")
    print()

    # Check CUDA before any GPU work
    if not torch.cuda.is_available():
        print("ERROR: CUDA is not available. This benchmark requires a GPU.")
        return

    device_name = torch.cuda.get_device_name(0)
    print(f"GPU: {device_name}")
    print()

    # Create configs
    configs = create_benchmark_configs(
        args.batch_sizes,
        args.vocab_sizes,
    )

    # Run benchmarks
    results = run_benchmark(
        configs,
        warmup_iters=args.warmup_iters,
        benchmark_iters=args.benchmark_iters,
        verbose=not args.quiet,
    )

    # Print summary
    print_summary_table(results)
|
||||
|
||||
|
||||
# Script entry point: run the full benchmark suite with CLI-provided options.
if __name__ == "__main__":
    main()
|
||||
@@ -145,6 +145,7 @@ async def test_request_cancellation(server: RemoteOpenAIServer):
|
||||
model=MODEL_NAME,
|
||||
max_tokens=10000,
|
||||
extra_body={"min_tokens": 10000},
|
||||
temperature=0.0,
|
||||
)
|
||||
)
|
||||
tasks.append(task)
|
||||
@@ -163,7 +164,7 @@ async def test_request_cancellation(server: RemoteOpenAIServer):
|
||||
# be able to respond to this one within the timeout
|
||||
client = server.get_async_client(timeout=5)
|
||||
response = await client.chat.completions.create(
|
||||
messages=chat_input, model=MODEL_NAME, max_tokens=10
|
||||
messages=chat_input, model=MODEL_NAME, max_tokens=10, temperature=0.0
|
||||
)
|
||||
|
||||
assert len(response.choices) == 1
|
||||
|
||||
@@ -5,8 +5,9 @@ import torch
|
||||
from torch import Generator
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_pytorch
|
||||
|
||||
CUDA_DEVICE = "cuda" if current_platform.is_cuda() else None
|
||||
DEVICE = current_platform.device_type
|
||||
|
||||
BATCH_SIZE = 1024
|
||||
@@ -39,11 +40,11 @@ def test_topk_impl_equivalence():
|
||||
)
|
||||
|
||||
# Top-k only implementation
|
||||
result1 = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)
|
||||
result1 = apply_top_k_top_p_pytorch(logits=logits.clone(), k=k, p=None)
|
||||
|
||||
# Top-p + top-k
|
||||
no_op_top_p = torch.tensor([1.0])
|
||||
result2 = apply_top_k_top_p(logits=logits.clone(), k=k, p=no_op_top_p)
|
||||
result2 = apply_top_k_top_p_pytorch(logits=logits.clone(), k=k, p=no_op_top_p)
|
||||
|
||||
assert torch.allclose(result1, result2)
|
||||
|
||||
@@ -98,7 +99,7 @@ def test_flashinfer_sampler():
|
||||
torch.randint(0, 2, (BATCH_SIZE,), generator=generator, dtype=torch.bool), 1.0
|
||||
)
|
||||
|
||||
python_logits = apply_top_k_top_p(
|
||||
python_logits = apply_top_k_top_p_pytorch(
|
||||
logits=logits.clone(),
|
||||
k=k_values,
|
||||
p=p_values,
|
||||
@@ -120,3 +121,451 @@ def test_flashinfer_sampler():
|
||||
assert torch.allclose(python_probs, flashinfer_probs, atol=2e-2), (
|
||||
"FlashInfer and Python sampling implementations do not match!"
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Triton kernel tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.skipif(CUDA_DEVICE is None, reason="CUDA not available")
|
||||
class TestTritonTopkTopp:
|
||||
"""Tests for the Triton top-k/top-p kernel."""
|
||||
|
||||
    @pytest.fixture(autouse=True)
    def setup(self):
        """Set up test fixtures.

        Runs automatically before every test in this class: routes default
        tensor creation to the CUDA device and seeds a fresh RNG so each
        test is reproducible.
        """
        torch.set_default_device(CUDA_DEVICE)
        self.generator = Generator(device=CUDA_DEVICE).manual_seed(42)
|
||||
|
||||
    def _compare_results(
        self,
        logits: torch.Tensor,
        k: torch.Tensor | None,
        p: torch.Tensor | None,
    ):
        """Compare Triton kernel results with PyTorch sorting implementation.

        For top-k only, we expect exact match of the kept-token counts.
        For top-p (with or without top-k), we allow small differences due to
        floating-point precision in probability sum calculations.
        """
        from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton

        # Clone logits for both implementations (they may modify in place)
        logits_pytorch = logits.clone()
        logits_triton = logits.clone().to(torch.float32)

        # Apply PyTorch sorting implementation
        result_pytorch = apply_top_k_top_p_pytorch(logits_pytorch, k, p)

        # Apply Triton kernel (cast to the dtypes the kernel expects)
        k_i32 = k.to(torch.int32) if k is not None else None
        p_f32 = p.to(torch.float32) if p is not None else None
        result_triton = apply_top_k_top_p_triton(logits_triton, k_i32, p_f32)

        # Compare kept (non-masked) counts per row; masked tokens are -inf
        pytorch_kept = (result_pytorch != float("-inf")).sum(dim=-1)
        triton_kept = (result_triton != float("-inf")).sum(dim=-1)

        if p is None:
            # Top-k only: expect exact match
            assert torch.equal(pytorch_kept, triton_kept), (
                f"Top-k mask mismatch: PyTorch kept {pytorch_kept.tolist()}, "
                f"Triton kept {triton_kept.tolist()}"
            )
        else:
            # Top-p involved: allow small differences.
            # Differences of up to 3 tokens (absolute) are always tolerated;
            # larger gaps must stay below 0.5% of the maximum kept count.
            max_diff = (pytorch_kept - triton_kept).abs().max().item()
            max_kept = pytorch_kept.max().item()
            if max_kept > 0 and max_diff > 3:
                diff_pct = max_diff / max_kept * 100
                assert diff_pct < 0.5, (
                    f"Top-p mask difference too large: {diff_pct:.2f}% "
                    f"(max diff {max_diff} values out of {max_kept})"
                )
|
||||
|
||||
    @pytest.mark.parametrize("batch_size", [1, 8, 32, 128, 512, 1024])
    @pytest.mark.parametrize("vocab_size", [1024, 32000, 128256])
    def test_topk_only(self, batch_size: int, vocab_size: int):
        """Test top-k only (p=None)."""
        logits = torch.randn(
            batch_size, vocab_size, generator=self.generator, dtype=torch.float32
        )
        # Random per-row k in [1, min(100, vocab_size)).
        k = torch.randint(
            1, min(100, vocab_size), (batch_size,), generator=self.generator
        )
        # Randomly disable top-k for some rows (~25%): k == vocab_size is a no-op.
        disable_mask = torch.randint(0, 4, (batch_size,), generator=self.generator) == 0
        k.masked_fill_(disable_mask, vocab_size)

        self._compare_results(logits, k, p=None)
|
||||
|
||||
    @pytest.mark.parametrize("batch_size", [1, 8, 32, 128, 512, 1024])
    @pytest.mark.parametrize("vocab_size", [1024, 32000, 128256])
    def test_topp_only(self, batch_size: int, vocab_size: int):
        """Test top-p only (k=None)."""
        logits = torch.randn(
            batch_size, vocab_size, generator=self.generator, dtype=torch.float32
        )
        p = torch.rand(batch_size, generator=self.generator) * 0.9 + 0.1  # [0.1, 1.0]
        # Randomly disable top-p for some rows (~25%): p == 1.0 is a no-op.
        disable_mask = torch.randint(0, 4, (batch_size,), generator=self.generator) == 0
        p.masked_fill_(disable_mask, 1.0)

        self._compare_results(logits, k=None, p=p)
|
||||
|
||||
    @pytest.mark.parametrize("batch_size", [1, 8, 32, 128, 512, 1024])
    @pytest.mark.parametrize("vocab_size", [1024, 32000, 128256])
    def test_topk_and_topp(self, batch_size: int, vocab_size: int):
        """Test combined top-k and top-p applied together."""
        logits = torch.randn(
            batch_size, vocab_size, generator=self.generator, dtype=torch.float32
        )
        k = torch.randint(
            1, min(100, vocab_size), (batch_size,), generator=self.generator
        )
        p = torch.rand(batch_size, generator=self.generator) * 0.9 + 0.1  # [0.1, 1.0]

        # Randomly disable top-k for some rows (~25%); k == vocab_size is a no-op.
        disable_k = torch.randint(0, 4, (batch_size,), generator=self.generator) == 0
        k.masked_fill_(disable_k, vocab_size)
        # Randomly disable top-p for some rows (~25%); p == 1.0 is a no-op.
        disable_p = torch.randint(0, 4, (batch_size,), generator=self.generator) == 0
        p.masked_fill_(disable_p, 1.0)

        self._compare_results(logits, k, p)
|
||||
|
||||
def test_both_disabled(self):
|
||||
"""Test when both k and p are None (should be no-op)."""
|
||||
from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton
|
||||
|
||||
logits = torch.randn(32, 1024, generator=self.generator, dtype=torch.float32)
|
||||
logits_clone = logits.clone()
|
||||
|
||||
result = apply_top_k_top_p_triton(logits_clone, k=None, p=None)
|
||||
|
||||
assert torch.equal(result, logits), "Should be no-op when both k and p are None"
|
||||
|
||||
def test_extreme_k_values(self):
|
||||
"""Test edge cases for k values."""
|
||||
batch_size, vocab_size = 16, 1024
|
||||
logits = torch.randn(
|
||||
batch_size, vocab_size, generator=self.generator, dtype=torch.float32
|
||||
)
|
||||
|
||||
# k=1 (keep only top 1)
|
||||
k = torch.ones(batch_size, dtype=torch.int32)
|
||||
self._compare_results(logits.clone(), k, p=None)
|
||||
|
||||
# k=vocab_size (keep all)
|
||||
k = torch.full((batch_size,), vocab_size, dtype=torch.int32)
|
||||
self._compare_results(logits.clone(), k, p=None)
|
||||
|
||||
# Mixed extreme values
|
||||
k = torch.tensor([1, vocab_size, 2, vocab_size - 1] * 4, dtype=torch.int32)
|
||||
self._compare_results(logits.clone(), k, p=None)
|
||||
|
||||
def test_extreme_p_values(self):
|
||||
"""Test edge cases for p values."""
|
||||
batch_size, vocab_size = 16, 1024
|
||||
logits = torch.randn(
|
||||
batch_size, vocab_size, generator=self.generator, dtype=torch.float32
|
||||
)
|
||||
|
||||
# p close to 0 (very restrictive)
|
||||
p = torch.full((batch_size,), 0.01, dtype=torch.float32)
|
||||
self._compare_results(logits.clone(), k=None, p=p)
|
||||
|
||||
# p=1.0 (keep all)
|
||||
p = torch.ones(batch_size, dtype=torch.float32)
|
||||
self._compare_results(logits.clone(), k=None, p=p)
|
||||
|
||||
# Mixed values
|
||||
p = torch.tensor([0.1, 0.5, 0.9, 1.0] * 4, dtype=torch.float32)
|
||||
self._compare_results(logits.clone(), k=None, p=p)
|
||||
|
||||
def test_large_batch(self):
|
||||
"""Test with a large batch size."""
|
||||
batch_size, vocab_size = 512, 32000
|
||||
logits = torch.randn(
|
||||
batch_size, vocab_size, generator=self.generator, dtype=torch.float32
|
||||
)
|
||||
k = torch.randint(1, 50, (batch_size,), generator=self.generator)
|
||||
p = torch.rand(batch_size, generator=self.generator) * 0.5 + 0.5
|
||||
|
||||
self._compare_results(logits, k, p)
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
# Tests for -inf logits (e.g. from grammar / structured output masks)
|
||||
# -----------------------------------------------------------------
|
||||
|
||||
    @pytest.mark.parametrize("inf_fraction", [0.5, 0.9, 0.99])
    def test_topk_with_neginf_logits(self, inf_fraction: float):
        """Top-k with many -inf logits (simulating grammar bitmask).

        The kernel must not produce NaN when most logits are -inf, which
        can happen when structured-output grammar masks are applied before
        sampling.
        """
        from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton

        batch_size, vocab_size = 32, 128256
        logits = torch.randn(
            batch_size, vocab_size, generator=self.generator, dtype=torch.float32
        )
        # Mask a fraction of logits to -inf.
        mask = (
            torch.rand(batch_size, vocab_size, generator=self.generator) < inf_fraction
        )
        logits[mask] = float("-inf")

        k = torch.randint(
            1, 50, (batch_size,), generator=self.generator, dtype=torch.int32
        )
        result = apply_top_k_top_p_triton(logits.clone(), k, None)

        assert not result.isnan().any(), "NaN found in top-k result with -inf logits"
        for i in range(batch_size):
            # Invariant 1: never keep more than k tokens per row.
            kept = (result[i] > float("-inf")).sum().item()
            assert kept <= k[i].item(), f"Row {i}: kept {kept} > k={k[i].item()}"
            # Invariant 2: at least one value should survive unless the row
            # was all -inf to begin with.
            finite_in = (logits[i] > float("-inf")).sum().item()
            if finite_in > 0:
                assert kept > 0, f"Row {i}: no tokens kept despite finite input"
|
||||
|
||||
    @pytest.mark.parametrize("inf_fraction", [0.5, 0.9, 0.99])
    def test_topp_with_neginf_logits(self, inf_fraction: float):
        """Top-p with many -inf logits.

        Checks two invariants: no NaN in the output, and every row that
        had at least one finite logit still keeps at least one token.
        """
        from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton

        batch_size, vocab_size = 32, 128256
        logits = torch.randn(
            batch_size, vocab_size, generator=self.generator, dtype=torch.float32
        )
        # Mask a fraction of logits to -inf (grammar-bitmask pattern).
        mask = (
            torch.rand(batch_size, vocab_size, generator=self.generator) < inf_fraction
        )
        logits[mask] = float("-inf")

        # p uniform in [0.1, 1.0)
        p = (
            torch.rand(batch_size, generator=self.generator, dtype=torch.float32) * 0.9
            + 0.1
        )
        result = apply_top_k_top_p_triton(logits.clone(), None, p)

        assert not result.isnan().any(), "NaN found in top-p result with -inf logits"
        for i in range(batch_size):
            finite_in = (logits[i] > float("-inf")).sum().item()
            kept = (result[i] > float("-inf")).sum().item()
            if finite_in > 0:
                assert kept > 0, f"Row {i}: no tokens kept despite finite input"
|
||||
|
||||
    @pytest.mark.parametrize("inf_fraction", [0.5, 0.9, 0.99])
    def test_topk_topp_with_neginf_logits(self, inf_fraction: float):
        """Combined top-k + top-p with many -inf logits.

        Checks that no NaN is produced and that no row keeps more than its
        k tokens.
        """
        from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton

        batch_size, vocab_size = 32, 128256
        logits = torch.randn(
            batch_size, vocab_size, generator=self.generator, dtype=torch.float32
        )
        # Mask a fraction of logits to -inf (grammar-bitmask pattern).
        mask = (
            torch.rand(batch_size, vocab_size, generator=self.generator) < inf_fraction
        )
        logits[mask] = float("-inf")

        k = torch.randint(
            1, 50, (batch_size,), generator=self.generator, dtype=torch.int32
        )
        # p uniform in [0.1, 1.0)
        p = (
            torch.rand(batch_size, generator=self.generator, dtype=torch.float32) * 0.9
            + 0.1
        )
        result = apply_top_k_top_p_triton(logits.clone(), k, p)

        assert not result.isnan().any(), (
            "NaN found in top-k+top-p result with -inf logits"
        )
        for i in range(batch_size):
            # top-p can only shrink the kept set further, so k is an upper bound.
            kept = (result[i] > float("-inf")).sum().item()
            assert kept <= k[i].item(), f"Row {i}: kept {kept} > k={k[i].item()}"
|
||||
|
||||
    def test_all_neginf_logits(self):
        """All logits are -inf (fully masked). Kernel should be a no-op.

        Exercises all three modes (k only, p only, k+p) against a fully
        masked batch: no NaN may appear and every value must remain -inf.
        """
        from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton

        batch_size, vocab_size = 16, 128256
        logits = torch.full(
            (batch_size, vocab_size), float("-inf"), dtype=torch.float32
        )

        k = torch.randint(
            1, 50, (batch_size,), generator=self.generator, dtype=torch.int32
        )
        p = torch.full((batch_size,), 0.9, dtype=torch.float32)

        # top-k only
        result = apply_top_k_top_p_triton(logits.clone(), k, None)
        assert not result.isnan().any(), "NaN from all-inf top-k"
        assert (result == float("-inf")).all(), "Expected all -inf unchanged"

        # top-p only
        result = apply_top_k_top_p_triton(logits.clone(), None, p)
        assert not result.isnan().any(), "NaN from all-inf top-p"
        assert (result == float("-inf")).all(), "Expected all -inf unchanged"

        # top-k + top-p
        result = apply_top_k_top_p_triton(logits.clone(), k, p)
        assert not result.isnan().any(), "NaN from all-inf top-k+top-p"
        assert (result == float("-inf")).all(), "Expected all -inf unchanged"
|
||||
|
||||
    def test_few_valid_tokens_with_neginf(self):
        """Only a handful of tokens are finite per row (strict grammar).

        Each row has exactly 5 finite logits; everything else is -inf.
        Verifies top-k keeps min(k, finite) tokens and top-p keeps at
        least one.
        """
        from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton

        batch_size, vocab_size = 32, 128256
        logits = torch.full(
            (batch_size, vocab_size), float("-inf"), dtype=torch.float32
        )
        # Allow only 5 random tokens per row to be finite.
        for i in range(batch_size):
            indices = torch.randperm(vocab_size, generator=self.generator)[:5]
            logits[i, indices] = torch.randn(
                5, generator=self.generator, dtype=torch.float32
            )

        k = torch.full((batch_size,), 50, dtype=torch.int32)
        p = torch.full((batch_size,), 0.9, dtype=torch.float32)

        # top-k only (k=50 but only 5 finite → keep all 5)
        result = apply_top_k_top_p_triton(logits.clone(), k, None)
        assert not result.isnan().any()
        for i in range(batch_size):
            kept = (result[i] > float("-inf")).sum().item()
            assert kept == 5, f"Row {i}: expected 5 kept, got {kept}"

        # top-k with k < num_finite: k becomes the binding constraint
        k_small = torch.full((batch_size,), 3, dtype=torch.int32)
        result = apply_top_k_top_p_triton(logits.clone(), k_small, None)
        assert not result.isnan().any()
        for i in range(batch_size):
            kept = (result[i] > float("-inf")).sum().item()
            assert kept <= 3, f"Row {i}: expected <=3 kept, got {kept}"

        # top-p only: at least one token must always survive
        result = apply_top_k_top_p_triton(logits.clone(), None, p)
        assert not result.isnan().any()
        for i in range(batch_size):
            kept = (result[i] > float("-inf")).sum().item()
            assert kept > 0, f"Row {i}: no tokens kept"
|
||||
|
||||
@pytest.mark.parametrize("num_valid", [1, 2, 5, 10, 50])
|
||||
@pytest.mark.parametrize(
|
||||
"mode",
|
||||
["topk_only", "topp_only", "topk_and_topp"],
|
||||
)
|
||||
def test_equal_logits_few_valid(self, num_valid: int, mode: str):
|
||||
"""Few valid tokens all sharing the same logit value.
|
||||
|
||||
This is the pattern produced by grammar bitmask filtering when
|
||||
the model assigns similar scores to the few allowed tokens.
|
||||
The ternary search can converge to a pivot equal to max_logit,
|
||||
causing the strict `>` keep_mask to exclude everything.
|
||||
Regression test for the `final_pivot >= max_logit` guard.
|
||||
"""
|
||||
from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton
|
||||
|
||||
batch_size, vocab_size = 32, 128256
|
||||
logits = torch.full(
|
||||
(batch_size, vocab_size), float("-inf"), dtype=torch.float32
|
||||
)
|
||||
# Set exactly `num_valid` tokens per row to the SAME finite value.
|
||||
for i in range(batch_size):
|
||||
indices = torch.randperm(vocab_size, generator=self.generator)[:num_valid]
|
||||
logits[i, indices] = 1.0 # all equal
|
||||
|
||||
k: torch.Tensor | None = None
|
||||
p: torch.Tensor | None = None
|
||||
if mode in ("topk_only", "topk_and_topp"):
|
||||
k = torch.full((batch_size,), max(1, num_valid - 1), dtype=torch.int32)
|
||||
if mode in ("topp_only", "topk_and_topp"):
|
||||
p = torch.full((batch_size,), 0.95, dtype=torch.float32)
|
||||
|
||||
result = apply_top_k_top_p_triton(logits.clone(), k, p)
|
||||
|
||||
assert not result.isnan().any(), "NaN in equal-logit result"
|
||||
for i in range(batch_size):
|
||||
kept = (result[i] > float("-inf")).sum().item()
|
||||
# The key invariant: at least one token must survive.
|
||||
# With all-equal logits the pivot search can't differentiate
|
||||
# tokens, so the guard may keep more than k — that is the
|
||||
# intended safe fallback.
|
||||
assert kept > 0, (
|
||||
f"Row {i}: all tokens masked with {num_valid} equal-valued "
|
||||
f"finite logits ({mode})"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("num_valid", [2, 5, 10])
|
||||
def test_nearly_equal_logits_topp(self, num_valid: int):
|
||||
"""Few valid tokens with very similar (but not identical) logits.
|
||||
|
||||
Ensures the kernel handles near-degenerate probability
|
||||
distributions where the ternary search range collapses.
|
||||
"""
|
||||
from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton
|
||||
|
||||
batch_size, vocab_size = 32, 128256
|
||||
logits = torch.full(
|
||||
(batch_size, vocab_size), float("-inf"), dtype=torch.float32
|
||||
)
|
||||
for i in range(batch_size):
|
||||
indices = torch.randperm(vocab_size, generator=self.generator)[:num_valid]
|
||||
# Tiny spread: values in [1.0, 1.0 + 1e-6]
|
||||
logits[i, indices] = (
|
||||
1.0
|
||||
+ torch.rand(num_valid, generator=self.generator, dtype=torch.float32)
|
||||
* 1e-6
|
||||
)
|
||||
|
||||
p = torch.full((batch_size,), 0.95, dtype=torch.float32)
|
||||
result = apply_top_k_top_p_triton(logits.clone(), None, p)
|
||||
|
||||
assert not result.isnan().any(), "NaN in nearly-equal-logit result"
|
||||
for i in range(batch_size):
|
||||
kept = (result[i] > float("-inf")).sum().item()
|
||||
assert kept > 0, (
|
||||
f"Row {i}: all tokens masked with {num_valid} "
|
||||
f"nearly-equal finite logits"
|
||||
)
|
||||
|
||||
def test_mixed_neginf_and_normal_rows(self):
|
||||
"""Batch with a mix of normal rows and heavily-masked rows."""
|
||||
from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton
|
||||
|
||||
batch_size, vocab_size = 32, 32000
|
||||
logits = torch.randn(
|
||||
batch_size, vocab_size, generator=self.generator, dtype=torch.float32
|
||||
)
|
||||
# Mask even rows heavily (99% -inf), leave odd rows normal.
|
||||
for i in range(0, batch_size, 2):
|
||||
mask = torch.rand(vocab_size, generator=self.generator) < 0.99
|
||||
logits[i][mask] = float("-inf")
|
||||
|
||||
k = torch.randint(
|
||||
1, 50, (batch_size,), generator=self.generator, dtype=torch.int32
|
||||
)
|
||||
p = (
|
||||
torch.rand(batch_size, generator=self.generator, dtype=torch.float32) * 0.9
|
||||
+ 0.1
|
||||
)
|
||||
|
||||
result = apply_top_k_top_p_triton(logits.clone(), k, p)
|
||||
assert not result.isnan().any(), "NaN in mixed normal/-inf batch"
|
||||
for i in range(batch_size):
|
||||
kept = (result[i] > float("-inf")).sum().item()
|
||||
assert kept <= k[i].item()
|
||||
finite_in = (logits[i] > float("-inf")).sum().item()
|
||||
if finite_in > 0:
|
||||
assert kept > 0, f"Row {i}: no tokens kept"
|
||||
|
||||
@@ -14,16 +14,12 @@ def cdiv(a: int, b: int) -> int:
|
||||
|
||||
def next_power_of_2(n: int) -> int:
|
||||
"""The next power of 2 (inclusive)"""
|
||||
if n < 1:
|
||||
return 1
|
||||
return 1 << (n - 1).bit_length()
|
||||
return 1 if n < 1 else 1 << (n - 1).bit_length()
|
||||
|
||||
|
||||
def prev_power_of_2(n: int) -> int:
|
||||
"""The previous power of 2 (inclusive)"""
|
||||
if n <= 0:
|
||||
return 0
|
||||
return 1 << (n.bit_length() - 1)
|
||||
return 0 if n <= 0 else 1 << (n.bit_length() - 1)
|
||||
|
||||
|
||||
def round_up(x: int, y: int) -> int:
|
||||
|
||||
@@ -11,6 +11,10 @@ from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm.config.model import LogprobsMode
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import CpuArchEnum, current_platform
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
|
||||
if HAS_TRITON:
|
||||
from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -87,8 +91,6 @@ class TopKTopPSampler(nn.Module):
|
||||
else:
|
||||
self.forward = self.forward_native
|
||||
|
||||
self.apply_top_k_top_p = apply_top_k_top_p
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
@@ -101,7 +103,7 @@ class TopKTopPSampler(nn.Module):
|
||||
|
||||
The logits tensor may be updated in-place.
|
||||
"""
|
||||
logits = self.apply_top_k_top_p(logits, k, p)
|
||||
logits = apply_top_k_top_p(logits, k, p)
|
||||
logits_to_return = None
|
||||
if self.logprobs_mode == "processed_logits":
|
||||
logits_to_return = logits
|
||||
@@ -149,7 +151,7 @@ class TopKTopPSampler(nn.Module):
|
||||
|
||||
The logits tensor may be updated in-place.
|
||||
"""
|
||||
logits = self.apply_top_k_top_p(logits, k, p)
|
||||
logits = apply_top_k_top_p_pytorch(logits, k, p, allow_cpu_sync=True)
|
||||
logits_to_return = None
|
||||
if self.logprobs_mode == "processed_logits":
|
||||
logits_to_return = logits
|
||||
@@ -158,14 +160,14 @@ class TopKTopPSampler(nn.Module):
|
||||
|
||||
if len(generators) != logits.shape[0]:
|
||||
return compiled_random_sample(logits), logits_to_return
|
||||
else:
|
||||
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
||||
q = torch.empty_like(probs)
|
||||
q.exponential_()
|
||||
for i, generator in generators.items():
|
||||
q[i].exponential_(generator=generator)
|
||||
|
||||
return probs.div_(q).argmax(dim=-1).view(-1), logits_to_return
|
||||
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
||||
q = torch.empty_like(probs)
|
||||
q.exponential_()
|
||||
for i, generator in generators.items():
|
||||
q[i].exponential_(generator=generator)
|
||||
|
||||
return probs.div_(q).argmax(dim=-1).view(-1), logits_to_return
|
||||
|
||||
def forward_hip(
|
||||
self,
|
||||
@@ -241,9 +243,23 @@ def compiled_random_sample(logits: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
|
||||
def apply_top_k_top_p(
|
||||
logits: torch.Tensor, k: torch.Tensor | None, p: torch.Tensor | None
|
||||
) -> torch.Tensor:
|
||||
if p is None and k is None:
|
||||
return logits
|
||||
|
||||
if HAS_TRITON and logits.shape[0] >= 8:
|
||||
return apply_top_k_top_p_triton(logits, k, p)
|
||||
|
||||
# Use pytorch sort implementation for small batch sizes.
|
||||
return apply_top_k_top_p_pytorch(logits, k, p)
|
||||
|
||||
|
||||
def apply_top_k_top_p_pytorch(
|
||||
logits: torch.Tensor,
|
||||
k: torch.Tensor | None,
|
||||
p: torch.Tensor | None,
|
||||
allow_cpu_sync: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""Apply top-k and top-p masks to the logits.
|
||||
|
||||
@@ -256,8 +272,9 @@ def apply_top_k_top_p(
|
||||
if k is None:
|
||||
return logits
|
||||
|
||||
# Avoid sorting vocab for top-k only case.
|
||||
return apply_top_k_only(logits, k)
|
||||
if allow_cpu_sync:
|
||||
# Avoid sorting vocab for top-k only case.
|
||||
return apply_top_k_only(logits, k)
|
||||
|
||||
logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
|
||||
|
||||
@@ -279,18 +296,16 @@ def apply_top_k_top_p(
|
||||
logits_sort.masked_fill_(top_p_mask, -float("inf"))
|
||||
|
||||
# Re-sort the probabilities.
|
||||
logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
|
||||
return logits
|
||||
return logits.scatter_(dim=-1, index=logits_idx, src=logits_sort)
|
||||
|
||||
|
||||
def apply_top_k_only(
|
||||
logits: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
def apply_top_k_only(logits: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Apply top-k mask to the logits.
|
||||
|
||||
This implementation doesn't involve sorting the entire vocab.
|
||||
Note however that it involves a GPU->CPU sync which can be detrimental for
|
||||
async scheduling performance.
|
||||
|
||||
The logits tensor may be updated in-place.
|
||||
"""
|
||||
@@ -304,8 +319,7 @@ def apply_top_k_only(
|
||||
top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long())
|
||||
# Handle non-topk rows.
|
||||
top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf"))
|
||||
logits.masked_fill_(logits < top_k_mask, -float("inf"))
|
||||
return logits
|
||||
return logits.masked_fill_(logits < top_k_mask, -float("inf"))
|
||||
|
||||
|
||||
def random_sample(
|
||||
|
||||
1039
vllm/v1/sample/ops/topk_topp_triton.py
Normal file
1039
vllm/v1/sample/ops/topk_topp_triton.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user