Fused kernels (zero CPU sync, single kernel launch per projection): - fused_amax_quantize.cu: amax→gsa→quantize in one pass. Replaces the two-step compute_amax_gsa_gpu + quantize_nvfp4_gpu (which required an .item() sync). - fused_deinterleave_amax_quantize.cu: Same for the MoE fused_swiglu L2 path. Deinterleave + amax + quantize in one pass. Replaces compute_amax_gsa_gpu + deinterleave_quantize_nvfp4_cuda (which required an .item() sync). All kernel loaders use dsv4/kernels/cuda/loader.py (compile-once cache). Previously every call JIT-compiled via torch.utils.cpp_extension.load (~100ms/call, ~500 calls/token); now the module is compiled once and the cached module is reused. Updated layers: - linear.py Nvfp4Linear._run_impl: fused kernel, gsa via GPU buffer - moe.py Nvfp4MoE._run_impl: fused for L1 and L2 (both fused_swiglu and non-fused paths) - shared_expert.py: fused for L1 and L2 - quantize.py: All functions use module loader cache - sampler.py: Uses module loader cache - indexer/score_topk.py: Uses module loader cache P2: Vectorized KVCache.append_swa — index_copy_ instead of a Python loop: 2 kernel launches instead of 2T (two per appended token, T tokens). No .item() in comp_pos either. P3: Pre-allocated comp_kv buffers — O(1) append instead of O(N) torch.cat. max_comp=32768 per layer (32MB). No more quadratic memory growth. ~486 .item() syncs per decoded token → ~0 (only argmax + token decode remain).
164 lines
5.8 KiB
Python
164 lines
5.8 KiB
Python
"""Production token sampler — fused CUDA kernel wrapper.
|
|
|
|
Implements temperature scaling, repetition penalty, top-k, top-p (nucleus) sampling.
|
|
All computation on GPU, zero CPU syncs, CUDA-graph-compatible.
|
|
|
|
Usage:
|
|
sampler = CUDASampler(device='cuda:0')
|
|
token_id = sampler(logits, temperature=0.6, top_k=50, top_p=0.95,
|
|
repetition_penalty=1.1, recent_tokens=token_history)
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
import os
|
|
import torch
|
|
from typing import Optional, List
|
|
|
|
# Memoized compiled "sampler" extension; stays None until first use.
_kernel = None


def _get_kernel():
    """Return the compiled ``sampler`` extension module, loading it on demand.

    The loader import is deferred into this function, so importing this file
    does not touch ``dsv4.kernels.cuda.loader``. The loaded module is cached
    in the module-level ``_kernel`` and returned directly on later calls.
    """
    global _kernel
    if _kernel is None:
        from dsv4.kernels.cuda.loader import get_cuda_module
        _kernel = get_cuda_module("sampler", ["sampler.cu"])
    return _kernel
|
|
|
|
|
|
class CUDASampler:
    """Production sampler with fused CUDA kernel.

    All sampling happens on GPU. No .item() calls, no CPU tensors returned.
    The output is a GPU int64 tensor — the caller can .item() once
    at the end of the decode loop, or keep it on GPU for further processing.

    Repetition penalty is written for batch row 0 only (the same convention
    PyTorchSampler uses). Any other rows receive padding entries
    (id 0, penalty 1.0), presumably a no-op in the kernel — TODO confirm
    against sampler.cu.
    """

    def __init__(self, device: str = 'cuda:0', max_penalty_tokens: int = 256):
        self.device = device
        # Only the most recent `max_penalty_tokens` history entries are penalized.
        self.max_penalty_tokens = max_penalty_tokens
        # Reusable penalty buffers of shape (rows, max_penalty_tokens). They are
        # reallocated (grown, never shrunk) when a larger batch arrives.
        self._penalty_ids_buf = torch.zeros(1, max_penalty_tokens, dtype=torch.int64, device=device)
        self._penalty_vals_buf = torch.ones(1, max_penalty_tokens, dtype=torch.float32, device=device)
        # Call counter, passed to the kernel as the RNG offset.
        self._step = 0

    def __call__(
        self,
        logits: torch.Tensor,  # (1, vocab_size) or (batch, vocab_size) BF16 or FP32
        temperature: float = 0.6,
        top_k: int = 50,
        top_p: float = 0.95,
        repetition_penalty: float = 1.0,
        min_tokens_to_keep: int = 1,
        recent_tokens: Optional[List[int]] = None,  # token IDs for repetition penalty
        seed: Optional[int] = None,
    ) -> torch.Tensor:  # (batch,) int64 on GPU
        """Sample tokens from logits using fused CUDA kernel.

        Args:
            logits: (vocab,) or (batch, vocab) scores; converted to FP32.
            temperature, top_k, top_p, min_tokens_to_keep: forwarded to the kernel.
            repetition_penalty: penalty factor; 1.0 disables the penalty.
            recent_tokens: token history. Only the last ``max_penalty_tokens``
                entries are used; ids outside ``[0, vocab)`` are ignored and
                duplicates are removed.
            seed: RNG seed; 42 when None. Each call also advances an internal
                offset that is passed to the kernel.

        Returns int64 tensor on GPU. Use .item() to get Python int if needed.
        """
        if logits.dim() == 1:
            logits = logits.unsqueeze(0)
        assert logits.dim() == 2

        # Convert to FP32 for the sampler kernel (no copy if already FP32).
        logits_f32 = logits.float()
        batch, vocab_size = logits_f32.shape

        if seed is None:
            seed = 42
        offset = self._step
        self._step += 1

        pen_ids, pen_vals = self._build_penalty(
            batch, vocab_size, repetition_penalty, recent_tokens)

        k = _get_kernel()
        result = k.sample(
            logits_f32,
            pen_ids,
            pen_vals,
            float(temperature),
            int(top_k),
            float(top_p),
            int(min_tokens_to_keep),
            int(seed),
            int(offset),
        )
        return result  # (batch,) int64 on GPU

    def _build_penalty(
        self,
        batch: int,
        vocab_size: int,
        repetition_penalty: float,
        recent_tokens: Optional[List[int]],
    ) -> tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Fill the penalty buffers; return (ids, vals) views of shape (batch, n_pen).

        Returns (None, None) when the penalty is disabled or no valid ids remain.
        """
        if repetition_penalty == 1.0 or not recent_tokens:
            return None, None

        # Window to the most recent tokens, drop ids the kernel could not
        # index, then deduplicate while keeping first-seen order.
        window = recent_tokens[-self.max_penalty_tokens:]
        unique_tokens = list(dict.fromkeys(t for t in window if 0 <= t < vocab_size))
        n_pen = len(unique_tokens)
        if n_pen == 0:
            return None, None

        # Grow the buffers if this batch has more rows than they do. (This
        # branch was previously unreachable, so batch > 1 silently skipped
        # the penalty.)
        if batch > self._penalty_ids_buf.shape[0]:
            self._penalty_ids_buf = torch.zeros(batch, self.max_penalty_tokens, dtype=torch.int64, device=self.device)
            self._penalty_vals_buf = torch.ones(batch, self.max_penalty_tokens, dtype=torch.float32, device=self.device)

        self._penalty_ids_buf.zero_()
        self._penalty_vals_buf.fill_(1.0)

        # One upload plus one fill, instead of 2 * n_pen single-element writes.
        ids_host = torch.tensor(unique_tokens, dtype=torch.int64)
        if self._penalty_ids_buf.is_cuda:
            # Pinned staging keeps the non_blocking upload asynchronous; a copy
            # from pageable host memory can block the host on the stream.
            ids_host = ids_host.pin_memory()
        self._penalty_ids_buf[0, :n_pen].copy_(ids_host, non_blocking=True)
        self._penalty_vals_buf[0, :n_pen].fill_(repetition_penalty)

        return self._penalty_ids_buf[:batch, :n_pen], self._penalty_vals_buf[:batch, :n_pen]
|
|
|
|
|
|
class PyTorchSampler:
    """Reference sampler using pure PyTorch ops (for correctness verification).

    Same API as CUDASampler. Used to verify the CUDA kernel produces
    the same distribution.

    Pipeline: repetition penalty (batch row 0 only) -> temperature ->
    top-k -> top-p (nucleus) -> multinomial draw.
    """

    def __init__(self, device: str = 'cuda:0'):
        # Kept for API parity with CUDASampler; computation runs on logits.device.
        self.device = device

    def __call__(
        self,
        logits: torch.Tensor,
        temperature: float = 0.6,
        top_k: int = 50,
        top_p: float = 0.95,
        repetition_penalty: float = 1.0,
        min_tokens_to_keep: int = 1,
        recent_tokens: Optional[List[int]] = None,
        seed: Optional[int] = None,
    ) -> torch.Tensor:
        """Sample one token id per row of ``logits``.

        Args:
            logits: (vocab,) or (batch, vocab) scores; never modified.
            temperature: logits are divided by this value.
            top_k: keep only the k highest logits; <= 0 disables.
            top_p: nucleus threshold; >= 1.0 disables.
            repetition_penalty: applied to row 0 for ids in ``recent_tokens``
                that lie in ``[0, vocab)``; 1.0 disables.
            min_tokens_to_keep: top-p never removes this many leading tokens.
            recent_tokens: token history used by the repetition penalty.
            seed: when given, the draw uses a private generator seeded with it
                (the global RNG state is left untouched).

        Returns:
            (batch,) int64 tensor on ``logits.device``.
        """
        if logits.dim() == 1:
            logits = logits.unsqueeze(0)
        # Private FP32 copy so in-place edits never touch the caller's tensor.
        logits = logits.float().clone()

        # Repetition penalty: positive logits are divided by the penalty,
        # non-positive ones multiplied by it.
        if repetition_penalty != 1.0 and recent_tokens:
            vocab_size = logits.shape[-1]
            ids = sorted({t for t in recent_tokens if 0 <= t < vocab_size})
            if ids:
                idx = torch.tensor(ids, dtype=torch.int64, device=logits.device)
                picked = logits[0, idx]
                logits[0, idx] = torch.where(
                    picked > 0, picked / repetition_penalty, picked * repetition_penalty)

        # Temperature
        logits = logits / temperature

        # Top-k: mask everything below the k-th largest logit.
        if top_k > 0:
            top_k = min(top_k, logits.shape[-1])
            kth_largest = torch.topk(logits, top_k).values[..., -1:]
            logits = logits.masked_fill(logits < kth_largest, float('-inf'))

        # Top-p (nucleus): drop a token once the mass strictly before it
        # already reaches top_p.
        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            sorted_probs = torch.softmax(sorted_logits, dim=-1)
            mass_before = torch.cumsum(sorted_probs, dim=-1) - sorted_probs
            sorted_remove = mass_before >= top_p
            sorted_remove[..., :min_tokens_to_keep] = False
            remove = sorted_remove.scatter(1, sorted_indices, sorted_remove)
            logits = logits.masked_fill(remove, float('-inf'))

        # Sample with a local generator instead of reseeding the global RNG.
        probs = torch.softmax(logits, dim=-1)
        generator = None
        if seed is not None:
            generator = torch.Generator(device=probs.device)
            generator.manual_seed(seed)
        return torch.multinomial(probs, 1, generator=generator).squeeze(-1).to(torch.int64)
|