[Kernel] Triton-based Top-k and Top-p sampler kernels (#33538)
Signed-off-by: js_park <cakeng@naver.com> Signed-off-by: Jongseok Park <37990712+cakeng@users.noreply.github.com> Signed-off-by: Sunga Kim <sunga.kim@berkeley.edu> Signed-off-by: Nick Hill <nickhill123@gmail.com> Co-authored-by: Sunga Kim <sunga.kim@berkeley.edu> Co-authored-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
@@ -5,8 +5,9 @@ import torch
|
||||
from torch import Generator
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
|
||||
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p_pytorch
|
||||
|
||||
CUDA_DEVICE = "cuda" if current_platform.is_cuda() else None
|
||||
DEVICE = current_platform.device_type
|
||||
|
||||
BATCH_SIZE = 1024
|
||||
@@ -39,11 +40,11 @@ def test_topk_impl_equivalence():
|
||||
)
|
||||
|
||||
# Top-k only implementation
|
||||
result1 = apply_top_k_top_p(logits=logits.clone(), k=k, p=None)
|
||||
result1 = apply_top_k_top_p_pytorch(logits=logits.clone(), k=k, p=None)
|
||||
|
||||
# Top-p + top-k
|
||||
no_op_top_p = torch.tensor([1.0])
|
||||
result2 = apply_top_k_top_p(logits=logits.clone(), k=k, p=no_op_top_p)
|
||||
result2 = apply_top_k_top_p_pytorch(logits=logits.clone(), k=k, p=no_op_top_p)
|
||||
|
||||
assert torch.allclose(result1, result2)
|
||||
|
||||
@@ -98,7 +99,7 @@ def test_flashinfer_sampler():
|
||||
torch.randint(0, 2, (BATCH_SIZE,), generator=generator, dtype=torch.bool), 1.0
|
||||
)
|
||||
|
||||
python_logits = apply_top_k_top_p(
|
||||
python_logits = apply_top_k_top_p_pytorch(
|
||||
logits=logits.clone(),
|
||||
k=k_values,
|
||||
p=p_values,
|
||||
@@ -120,3 +121,451 @@ def test_flashinfer_sampler():
|
||||
assert torch.allclose(python_probs, flashinfer_probs, atol=2e-2), (
|
||||
"FlashInfer and Python sampling implementations do not match!"
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Triton kernel tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.skipif(CUDA_DEVICE is None, reason="CUDA not available")
class TestTritonTopkTopp:
    """Tests for the Triton top-k/top-p kernel."""

    @pytest.fixture(autouse=True)
    def setup(self):
        """Set up test fixtures.

        Sets the default torch device for the duration of a test and
        restores it afterwards so the global default-device state does
        not leak into other tests in the session.
        """
        torch.set_default_device(CUDA_DEVICE)
        self.generator = Generator(device=CUDA_DEVICE).manual_seed(42)
        yield
        # Teardown: undo the global default-device change.
        torch.set_default_device(None)

    def _compare_results(
        self,
        logits: torch.Tensor,
        k: torch.Tensor | None,
        p: torch.Tensor | None,
    ):
        """Compare Triton kernel results with PyTorch sorting implementation.

        For top-k only, we expect exact match.
        For top-p (with or without top-k), we allow small differences due to
        floating-point precision in probability sum calculations.
        """
        from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton

        # Clone logits for both implementations
        logits_pytorch = logits.clone()
        logits_triton = logits.clone().to(torch.float32)

        # Apply PyTorch sorting implementation
        result_pytorch = apply_top_k_top_p_pytorch(logits_pytorch, k, p)

        # Apply Triton kernel (kernel expects int32 k and float32 p).
        k_i32 = k.to(torch.int32) if k is not None else None
        p_f32 = p.to(torch.float32) if p is not None else None
        result_triton = apply_top_k_top_p_triton(logits_triton, k_i32, p_f32)

        # Compare kept counts per row
        pytorch_kept = (result_pytorch != float("-inf")).sum(dim=-1)
        triton_kept = (result_triton != float("-inf")).sum(dim=-1)

        if p is None:
            # Top-k only: expect exact match
            assert torch.equal(pytorch_kept, triton_kept), (
                f"Top-k mask mismatch: PyTorch kept {pytorch_kept.tolist()}, "
                f"Triton kept {triton_kept.tolist()}"
            )
        else:
            # Top-p involved: allow small differences.
            # Up to 3 values absolute difference is always tolerated;
            # beyond that, the difference must stay below 0.5% of the
            # largest per-row kept count.
            max_diff = (pytorch_kept - triton_kept).abs().max().item()
            max_kept = pytorch_kept.max().item()
            if max_kept > 0 and max_diff > 3:
                diff_pct = max_diff / max_kept * 100
                assert diff_pct < 0.5, (
                    f"Top-p mask difference too large: {diff_pct:.2f}% "
                    f"(max diff {max_diff} values out of {max_kept})"
                )

    @pytest.mark.parametrize("batch_size", [1, 8, 32, 128, 512, 1024])
    @pytest.mark.parametrize("vocab_size", [1024, 32000, 128256])
    def test_topk_only(self, batch_size: int, vocab_size: int):
        """Test top-k only (p=None)."""
        logits = torch.randn(
            batch_size, vocab_size, generator=self.generator, dtype=torch.float32
        )
        k = torch.randint(
            1, min(100, vocab_size), (batch_size,), generator=self.generator
        )
        # Randomly disable top-k for some rows (~25%)
        disable_mask = torch.randint(0, 4, (batch_size,), generator=self.generator) == 0
        k.masked_fill_(disable_mask, vocab_size)

        self._compare_results(logits, k, p=None)

    @pytest.mark.parametrize("batch_size", [1, 8, 32, 128, 512, 1024])
    @pytest.mark.parametrize("vocab_size", [1024, 32000, 128256])
    def test_topp_only(self, batch_size: int, vocab_size: int):
        """Test top-p only (k=None)."""
        logits = torch.randn(
            batch_size, vocab_size, generator=self.generator, dtype=torch.float32
        )
        p = torch.rand(batch_size, generator=self.generator) * 0.9 + 0.1  # [0.1, 1.0]
        # Randomly disable top-p for some rows (~25%)
        disable_mask = torch.randint(0, 4, (batch_size,), generator=self.generator) == 0
        p.masked_fill_(disable_mask, 1.0)

        self._compare_results(logits, k=None, p=p)

    @pytest.mark.parametrize("batch_size", [1, 8, 32, 128, 512, 1024])
    @pytest.mark.parametrize("vocab_size", [1024, 32000, 128256])
    def test_topk_and_topp(self, batch_size: int, vocab_size: int):
        """Test combined top-k and top-p."""
        logits = torch.randn(
            batch_size, vocab_size, generator=self.generator, dtype=torch.float32
        )
        k = torch.randint(
            1, min(100, vocab_size), (batch_size,), generator=self.generator
        )
        p = torch.rand(batch_size, generator=self.generator) * 0.9 + 0.1  # [0.1, 1.0]

        # Randomly disable top-k for some rows (~25%)
        disable_k = torch.randint(0, 4, (batch_size,), generator=self.generator) == 0
        k.masked_fill_(disable_k, vocab_size)
        # Randomly disable top-p for some rows (~25%)
        disable_p = torch.randint(0, 4, (batch_size,), generator=self.generator) == 0
        p.masked_fill_(disable_p, 1.0)

        self._compare_results(logits, k, p)

    def test_both_disabled(self):
        """Test when both k and p are None (should be no-op)."""
        from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton

        logits = torch.randn(32, 1024, generator=self.generator, dtype=torch.float32)
        logits_clone = logits.clone()

        result = apply_top_k_top_p_triton(logits_clone, k=None, p=None)

        assert torch.equal(result, logits), "Should be no-op when both k and p are None"

    def test_extreme_k_values(self):
        """Test edge cases for k values."""
        batch_size, vocab_size = 16, 1024
        logits = torch.randn(
            batch_size, vocab_size, generator=self.generator, dtype=torch.float32
        )

        # k=1 (keep only top 1)
        k = torch.ones(batch_size, dtype=torch.int32)
        self._compare_results(logits.clone(), k, p=None)

        # k=vocab_size (keep all)
        k = torch.full((batch_size,), vocab_size, dtype=torch.int32)
        self._compare_results(logits.clone(), k, p=None)

        # Mixed extreme values
        k = torch.tensor([1, vocab_size, 2, vocab_size - 1] * 4, dtype=torch.int32)
        self._compare_results(logits.clone(), k, p=None)

    def test_extreme_p_values(self):
        """Test edge cases for p values."""
        batch_size, vocab_size = 16, 1024
        logits = torch.randn(
            batch_size, vocab_size, generator=self.generator, dtype=torch.float32
        )

        # p close to 0 (very restrictive)
        p = torch.full((batch_size,), 0.01, dtype=torch.float32)
        self._compare_results(logits.clone(), k=None, p=p)

        # p=1.0 (keep all)
        p = torch.ones(batch_size, dtype=torch.float32)
        self._compare_results(logits.clone(), k=None, p=p)

        # Mixed values
        p = torch.tensor([0.1, 0.5, 0.9, 1.0] * 4, dtype=torch.float32)
        self._compare_results(logits.clone(), k=None, p=p)

    def test_large_batch(self):
        """Test with a large batch size."""
        batch_size, vocab_size = 512, 32000
        logits = torch.randn(
            batch_size, vocab_size, generator=self.generator, dtype=torch.float32
        )
        k = torch.randint(1, 50, (batch_size,), generator=self.generator)
        p = torch.rand(batch_size, generator=self.generator) * 0.5 + 0.5

        self._compare_results(logits, k, p)

    # -----------------------------------------------------------------
    # Tests for -inf logits (e.g. from grammar / structured output masks)
    # -----------------------------------------------------------------

    @pytest.mark.parametrize("inf_fraction", [0.5, 0.9, 0.99])
    def test_topk_with_neginf_logits(self, inf_fraction: float):
        """Top-k with many -inf logits (simulating grammar bitmask).

        The kernel must not produce NaN when most logits are -inf, which
        can happen when structured-output grammar masks are applied before
        sampling.
        """
        from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton

        batch_size, vocab_size = 32, 128256
        logits = torch.randn(
            batch_size, vocab_size, generator=self.generator, dtype=torch.float32
        )
        # Mask a fraction of logits to -inf.
        mask = (
            torch.rand(batch_size, vocab_size, generator=self.generator) < inf_fraction
        )
        logits[mask] = float("-inf")

        k = torch.randint(
            1, 50, (batch_size,), generator=self.generator, dtype=torch.int32
        )
        result = apply_top_k_top_p_triton(logits.clone(), k, None)

        assert not result.isnan().any(), "NaN found in top-k result with -inf logits"
        for i in range(batch_size):
            kept = (result[i] > float("-inf")).sum().item()
            assert kept <= k[i].item(), f"Row {i}: kept {kept} > k={k[i].item()}"
            # At least one value should survive unless the row was all -inf.
            finite_in = (logits[i] > float("-inf")).sum().item()
            if finite_in > 0:
                assert kept > 0, f"Row {i}: no tokens kept despite finite input"

    @pytest.mark.parametrize("inf_fraction", [0.5, 0.9, 0.99])
    def test_topp_with_neginf_logits(self, inf_fraction: float):
        """Top-p with many -inf logits."""
        from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton

        batch_size, vocab_size = 32, 128256
        logits = torch.randn(
            batch_size, vocab_size, generator=self.generator, dtype=torch.float32
        )
        mask = (
            torch.rand(batch_size, vocab_size, generator=self.generator) < inf_fraction
        )
        logits[mask] = float("-inf")

        p = (
            torch.rand(batch_size, generator=self.generator, dtype=torch.float32) * 0.9
            + 0.1
        )
        result = apply_top_k_top_p_triton(logits.clone(), None, p)

        assert not result.isnan().any(), "NaN found in top-p result with -inf logits"
        for i in range(batch_size):
            finite_in = (logits[i] > float("-inf")).sum().item()
            kept = (result[i] > float("-inf")).sum().item()
            if finite_in > 0:
                assert kept > 0, f"Row {i}: no tokens kept despite finite input"

    @pytest.mark.parametrize("inf_fraction", [0.5, 0.9, 0.99])
    def test_topk_topp_with_neginf_logits(self, inf_fraction: float):
        """Combined top-k + top-p with many -inf logits."""
        from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton

        batch_size, vocab_size = 32, 128256
        logits = torch.randn(
            batch_size, vocab_size, generator=self.generator, dtype=torch.float32
        )
        mask = (
            torch.rand(batch_size, vocab_size, generator=self.generator) < inf_fraction
        )
        logits[mask] = float("-inf")

        k = torch.randint(
            1, 50, (batch_size,), generator=self.generator, dtype=torch.int32
        )
        p = (
            torch.rand(batch_size, generator=self.generator, dtype=torch.float32) * 0.9
            + 0.1
        )
        result = apply_top_k_top_p_triton(logits.clone(), k, p)

        assert not result.isnan().any(), (
            "NaN found in top-k+top-p result with -inf logits"
        )
        for i in range(batch_size):
            kept = (result[i] > float("-inf")).sum().item()
            assert kept <= k[i].item(), f"Row {i}: kept {kept} > k={k[i].item()}"

    def test_all_neginf_logits(self):
        """All logits are -inf (fully masked). Kernel should be a no-op."""
        from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton

        batch_size, vocab_size = 16, 128256
        logits = torch.full(
            (batch_size, vocab_size), float("-inf"), dtype=torch.float32
        )

        k = torch.randint(
            1, 50, (batch_size,), generator=self.generator, dtype=torch.int32
        )
        p = torch.full((batch_size,), 0.9, dtype=torch.float32)

        # top-k only
        result = apply_top_k_top_p_triton(logits.clone(), k, None)
        assert not result.isnan().any(), "NaN from all-inf top-k"
        assert (result == float("-inf")).all(), "Expected all -inf unchanged"

        # top-p only
        result = apply_top_k_top_p_triton(logits.clone(), None, p)
        assert not result.isnan().any(), "NaN from all-inf top-p"
        assert (result == float("-inf")).all(), "Expected all -inf unchanged"

        # top-k + top-p
        result = apply_top_k_top_p_triton(logits.clone(), k, p)
        assert not result.isnan().any(), "NaN from all-inf top-k+top-p"
        assert (result == float("-inf")).all(), "Expected all -inf unchanged"

    def test_few_valid_tokens_with_neginf(self):
        """Only a handful of tokens are finite per row (strict grammar)."""
        from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton

        batch_size, vocab_size = 32, 128256
        logits = torch.full(
            (batch_size, vocab_size), float("-inf"), dtype=torch.float32
        )
        # Allow only 5 random tokens per row to be finite.
        for i in range(batch_size):
            indices = torch.randperm(vocab_size, generator=self.generator)[:5]
            logits[i, indices] = torch.randn(
                5, generator=self.generator, dtype=torch.float32
            )

        k = torch.full((batch_size,), 50, dtype=torch.int32)
        p = torch.full((batch_size,), 0.9, dtype=torch.float32)

        # top-k only (k=50 but only 5 finite → keep all 5)
        result = apply_top_k_top_p_triton(logits.clone(), k, None)
        assert not result.isnan().any()
        for i in range(batch_size):
            kept = (result[i] > float("-inf")).sum().item()
            assert kept == 5, f"Row {i}: expected 5 kept, got {kept}"

        # top-k with k < num_finite
        k_small = torch.full((batch_size,), 3, dtype=torch.int32)
        result = apply_top_k_top_p_triton(logits.clone(), k_small, None)
        assert not result.isnan().any()
        for i in range(batch_size):
            kept = (result[i] > float("-inf")).sum().item()
            assert kept <= 3, f"Row {i}: expected <=3 kept, got {kept}"

        # top-p only
        result = apply_top_k_top_p_triton(logits.clone(), None, p)
        assert not result.isnan().any()
        for i in range(batch_size):
            kept = (result[i] > float("-inf")).sum().item()
            assert kept > 0, f"Row {i}: no tokens kept"

    @pytest.mark.parametrize("num_valid", [1, 2, 5, 10, 50])
    @pytest.mark.parametrize(
        "mode",
        ["topk_only", "topp_only", "topk_and_topp"],
    )
    def test_equal_logits_few_valid(self, num_valid: int, mode: str):
        """Few valid tokens all sharing the same logit value.

        This is the pattern produced by grammar bitmask filtering when
        the model assigns similar scores to the few allowed tokens.
        The ternary search can converge to a pivot equal to max_logit,
        causing the strict `>` keep_mask to exclude everything.
        Regression test for the `final_pivot >= max_logit` guard.
        """
        from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton

        batch_size, vocab_size = 32, 128256
        logits = torch.full(
            (batch_size, vocab_size), float("-inf"), dtype=torch.float32
        )
        # Set exactly `num_valid` tokens per row to the SAME finite value.
        for i in range(batch_size):
            indices = torch.randperm(vocab_size, generator=self.generator)[:num_valid]
            logits[i, indices] = 1.0  # all equal

        k: torch.Tensor | None = None
        p: torch.Tensor | None = None
        if mode in ("topk_only", "topk_and_topp"):
            k = torch.full((batch_size,), max(1, num_valid - 1), dtype=torch.int32)
        if mode in ("topp_only", "topk_and_topp"):
            p = torch.full((batch_size,), 0.95, dtype=torch.float32)

        result = apply_top_k_top_p_triton(logits.clone(), k, p)

        assert not result.isnan().any(), "NaN in equal-logit result"
        for i in range(batch_size):
            kept = (result[i] > float("-inf")).sum().item()
            # The key invariant: at least one token must survive.
            # With all-equal logits the pivot search can't differentiate
            # tokens, so the guard may keep more than k — that is the
            # intended safe fallback.
            assert kept > 0, (
                f"Row {i}: all tokens masked with {num_valid} equal-valued "
                f"finite logits ({mode})"
            )

    @pytest.mark.parametrize("num_valid", [2, 5, 10])
    def test_nearly_equal_logits_topp(self, num_valid: int):
        """Few valid tokens with very similar (but not identical) logits.

        Ensures the kernel handles near-degenerate probability
        distributions where the ternary search range collapses.
        """
        from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton

        batch_size, vocab_size = 32, 128256
        logits = torch.full(
            (batch_size, vocab_size), float("-inf"), dtype=torch.float32
        )
        for i in range(batch_size):
            indices = torch.randperm(vocab_size, generator=self.generator)[:num_valid]
            # Tiny spread: values in [1.0, 1.0 + 1e-6]
            logits[i, indices] = (
                1.0
                + torch.rand(num_valid, generator=self.generator, dtype=torch.float32)
                * 1e-6
            )

        p = torch.full((batch_size,), 0.95, dtype=torch.float32)
        result = apply_top_k_top_p_triton(logits.clone(), None, p)

        assert not result.isnan().any(), "NaN in nearly-equal-logit result"
        for i in range(batch_size):
            kept = (result[i] > float("-inf")).sum().item()
            assert kept > 0, (
                f"Row {i}: all tokens masked with {num_valid} "
                f"nearly-equal finite logits"
            )

    def test_mixed_neginf_and_normal_rows(self):
        """Batch with a mix of normal rows and heavily-masked rows."""
        from vllm.v1.sample.ops.topk_topp_triton import apply_top_k_top_p_triton

        batch_size, vocab_size = 32, 32000
        logits = torch.randn(
            batch_size, vocab_size, generator=self.generator, dtype=torch.float32
        )
        # Mask even rows heavily (99% -inf), leave odd rows normal.
        for i in range(0, batch_size, 2):
            mask = torch.rand(vocab_size, generator=self.generator) < 0.99
            logits[i][mask] = float("-inf")

        k = torch.randint(
            1, 50, (batch_size,), generator=self.generator, dtype=torch.int32
        )
        p = (
            torch.rand(batch_size, generator=self.generator, dtype=torch.float32) * 0.9
            + 0.1
        )

        result = apply_top_k_top_p_triton(logits.clone(), k, p)
        assert not result.isnan().any(), "NaN in mixed normal/-inf batch"
        for i in range(batch_size):
            kept = (result[i] > float("-inf")).sum().item()
            assert kept <= k[i].item()
            finite_in = (logits[i] > float("-inf")).sum().item()
            if finite_in > 0:
                assert kept > 0, f"Row {i}: no tokens kept"
Reference in New Issue
Block a user