Files
nvfp4-megamoe-kernel/dsv4/model/sampler.py
biondizzle 4f698baa5d Production fused CUDA sampler + decode loop optimizations
- Add dsv4/kernels/cuda/sampler.cu: fused temperature + repetition penalty
  + top-k + top-p (nucleus) sampling, single kernel launch, zero CPU syncs
- Add dsv4/model/sampler.py: CUDASampler wrapper + PyTorch reference
- Update single_shot_inference.py:
  - Use CUDASampler for non-greedy decoding (temperature=0.6, top_k=50, top_p=0.95)
  - Pre-allocate decode buffers (no per-step torch.tensor allocation)
  - Track thinking tokens (128821/128822) — not garbage for reasoning model
  - Reduce diagnostic CPU syncs (top-5 every 5 steps, NaN check every 20)
  - Add --top-k and --top-p CLI args
  - Default: temperature=0.6 (was 0.0 greedy), rep_penalty=1.1 (was 1.2)
2026-06-01 20:29:57 +00:00

170 lines
6.1 KiB
Python

"""Production token sampler — fused CUDA kernel wrapper.
Implements temperature scaling, repetition penalty, top-k, top-p (nucleus) sampling.
All computation on GPU, zero CPU syncs, CUDA-graph-compatible.
Usage:
    sampler = CUDASampler(device='cuda:0')
    token_id = sampler(logits, temperature=0.6, top_k=50, top_p=0.95,
                       repetition_penalty=1.1, recent_tokens=token_history)
"""
from __future__ import annotations
import os
import torch
from typing import Optional, List
# Process-wide cache for the JIT-compiled sampler extension (None until built).
_kernel = None


def _get_kernel():
    """Return the compiled ``dsv4_sampler`` extension, building it on first use.

    The first call compiles ``../kernels/cuda/sampler.cu`` (resolved relative
    to this file) through ``torch.utils.cpp_extension.load`` and stores the
    resulting module in the module-level ``_kernel``; every later call
    returns that cached module without rebuilding.
    """
    global _kernel
    if _kernel is None:
        # Imported lazily, so the extension build tooling is only loaded
        # when a kernel is actually requested.
        from torch.utils.cpp_extension import load

        cuda_dir = os.path.join(os.path.dirname(__file__), "..", "kernels", "cuda")
        source_path = os.path.join(cuda_dir, "sampler.cu")
        nvcc_flags = ["-O3", "--generate-code=arch=compute_100a,code=[sm_100a]"]
        _kernel = load(
            name="dsv4_sampler",
            sources=[source_path],
            extra_cuda_cflags=nvcc_flags,
            verbose=False,
        )
    return _kernel
class CUDASampler:
    """Token sampler that delegates to the fused CUDA ``sample`` kernel.

    The instance owns two pre-allocated device buffers (penalised token ids
    and their penalty factors) that are reused across calls, plus a step
    counter that is handed to the kernel as the RNG offset.

    The sampled ids are returned as a tensor straight from the kernel; no
    ``.item()`` is called here, so the caller decides when to synchronise.
    The token history arrives as a Python list and is staged to the device
    with a single asynchronous copy per call.
    """

    def __init__(self, device: str = 'cuda:0', max_penalty_tokens: int = 256):
        """Allocate the reusable penalty buffers on ``device``.

        Args:
            device: Device on which the penalty buffers live.
            max_penalty_tokens: Size of the trailing window of the token
                history that is considered for the repetition penalty, and
                the column capacity of the penalty buffers.
        """
        self.device = device
        self.max_penalty_tokens = max_penalty_tokens
        # Start with capacity for a single row; grown on demand for batches.
        self._penalty_ids_buf = torch.zeros(
            1, max_penalty_tokens, dtype=torch.int64, device=device)
        self._penalty_vals_buf = torch.ones(
            1, max_penalty_tokens, dtype=torch.float32, device=device)
        self._step = 0

    def _ensure_penalty_capacity(self, batch: int) -> None:
        """Grow the penalty buffers so they hold at least ``batch`` rows."""
        if batch > self._penalty_ids_buf.shape[0]:
            self._penalty_ids_buf = torch.zeros(
                batch, self.max_penalty_tokens,
                dtype=torch.int64, device=self.device)
            self._penalty_vals_buf = torch.ones(
                batch, self.max_penalty_tokens,
                dtype=torch.float32, device=self.device)

    def _stage_penalties(
        self,
        batch: int,
        vocab_size: int,
        repetition_penalty: float,
        recent_tokens: Optional[List[int]],
    ):
        """Build the ``(pen_ids, pen_vals)`` kernel arguments.

        Returns ``(None, None)`` when no penalty applies. Otherwise returns
        two ``(batch, n_pen)`` tensors in which every row carries the same
        penalised ids and the same penalty factor.
        """
        if repetition_penalty == 1.0 or not recent_tokens:
            return None, None
        # ``lst[-0:]`` is the whole list, so a zero-sized window must be
        # handled explicitly instead of through slicing.
        if self.max_penalty_tokens <= 0:
            return None, None

        window = recent_tokens[-self.max_penalty_tokens:]
        # Order-preserving de-duplication. Ids outside the vocabulary cannot
        # be penalised and are dropped rather than forwarded to the kernel
        # (the PyTorch reference sampler skips them as well).
        unique_tokens = [t for t in dict.fromkeys(window) if 0 <= t < vocab_size]
        n_pen = len(unique_tokens)
        if n_pen == 0:
            return None, None

        self._ensure_penalty_capacity(batch)

        # One host->device copy for the whole id list instead of one device
        # write per token. Pinned staging lets the copy be issued without
        # blocking the host.
        ids_host = torch.tensor(unique_tokens, dtype=torch.int64)
        if torch.device(self.device).type == 'cuda':
            ids_host = ids_host.pin_memory()
        first_row = self._penalty_ids_buf[0, :n_pen]
        first_row.copy_(ids_host, non_blocking=True)
        if batch > 1:
            # The history is shared, so replicate it on-device for every row.
            self._penalty_ids_buf[1:batch, :n_pen] = first_row
        self._penalty_vals_buf[:batch, :n_pen] = repetition_penalty

        # Only the region written above is exposed, so stale data elsewhere
        # in the buffers is never read. ``contiguous()`` is a no-op for a
        # single row; for batches it yields densely packed rows.
        # NOTE(review): assumes the kernel expects densely packed
        # (batch, n_pen) penalty tensors — confirm against sampler.cu.
        pen_ids = self._penalty_ids_buf[:batch, :n_pen].contiguous()
        pen_vals = self._penalty_vals_buf[:batch, :n_pen].contiguous()
        return pen_ids, pen_vals

    def __call__(
        self,
        logits: torch.Tensor,  # (vocab_size,), (1, vocab_size) or (batch, vocab_size)
        temperature: float = 0.6,
        top_k: int = 50,
        top_p: float = 0.95,
        repetition_penalty: float = 1.0,
        min_tokens_to_keep: int = 1,
        recent_tokens: Optional[List[int]] = None,  # token IDs for repetition penalty
        seed: Optional[int] = None,
    ) -> torch.Tensor:  # (batch,) int64 on GPU
        """Sample one token per row of ``logits`` with the fused kernel.

        Args:
            logits: 1-D or 2-D logits; a 1-D tensor is treated as a batch of
                one. Converted to FP32 before being passed to the kernel.
            temperature: Forwarded to the kernel as a float.
            top_k: Forwarded to the kernel as an int.
            top_p: Forwarded to the kernel as a float.
            repetition_penalty: Penalty factor; ``1.0`` disables the penalty.
            min_tokens_to_keep: Forwarded to the kernel as an int.
            recent_tokens: Token history. The last ``max_penalty_tokens``
                entries are de-duplicated and penalised in every batch row.
            seed: RNG seed; defaults to ``42`` when omitted. The per-instance
                call counter is passed alongside it as the offset, so
                successive calls do not reuse the same (seed, offset) pair.

        Returns:
            The tensor produced by the kernel's ``sample`` entry point
            (``(batch,)`` int64 on the GPU).

        Raises:
            ValueError: If ``logits`` is neither 1-D nor 2-D.
        """
        if logits.dim() == 1:
            logits = logits.unsqueeze(0)
        if logits.dim() != 2:
            raise ValueError(
                f"logits must be 1-D or 2-D, got {logits.dim()} dimensions")

        # Convert to FP32 for the sampler kernel.
        logits_f32 = logits.float()
        batch, vocab_size = logits_f32.shape

        if seed is None:
            seed = 42
        offset = self._step
        self._step += 1

        pen_ids, pen_vals = self._stage_penalties(
            batch, vocab_size, repetition_penalty, recent_tokens)

        result = _get_kernel().sample(
            logits_f32,
            pen_ids,
            pen_vals,
            float(temperature),
            int(top_k),
            float(top_p),
            int(min_tokens_to_keep),
            int(seed),
            int(offset),
        )
        return result  # (batch,) int64 on GPU
class PyTorchSampler:
"""Reference sampler using pure PyTorch ops (for correctness verification).
Same API as CUDASampler. Used to verify the CUDA kernel produces
the same distribution.
"""
def __init__(self, device: str = 'cuda:0'):
self.device = device
def __call__(
self,
logits: torch.Tensor,
temperature: float = 0.6,
top_k: int = 50,
top_p: float = 0.95,
repetition_penalty: float = 1.0,
min_tokens_to_keep: int = 1,
recent_tokens: Optional[List[int]] = None,
seed: Optional[int] = None,
) -> torch.Tensor:
if logits.dim() == 1:
logits = logits.unsqueeze(0)
logits = logits.float().clone()
# Repetition penalty
if repetition_penalty != 1.0 and recent_tokens:
for tid in set(recent_tokens):
if 0 <= tid < logits.shape[-1]:
if logits[0, tid] > 0:
logits[0, tid] /= repetition_penalty
else:
logits[0, tid] *= repetition_penalty
# Temperature
logits = logits / temperature
# Top-k
if top_k > 0:
top_k = min(top_k, logits.shape[-1])
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = -float('inf')
# Top-p (nucleus)
if top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cumulative_probs - torch.softmax(sorted_logits, dim=-1) >= top_p
sorted_indices_to_remove[..., :min_tokens_to_keep] = False
indices_to_remove = sorted_indices_to_remove.scatter(
1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = -float('inf')
# Sample
probs = torch.softmax(logits, dim=-1)
if seed is not None:
torch.manual_seed(seed)
return torch.multinomial(probs, 1).squeeze(-1).to(torch.int64)