[MRV2] Use FP64 for Gumbel noise (#37798)

Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
This commit is contained in:
Woosuk Kwon
2026-03-22 12:28:10 -07:00
committed by GitHub
parent c058ff44d4
commit a5e9d511de
2 changed files with 24 additions and 8 deletions

View File

@@ -49,6 +49,21 @@ def apply_temperature(
)
@triton.jit
def tl_rand64(seed, offset, includes_zero: tl.constexpr):
    """Return float64 uniform random numbers built from 64 random bits.

    Args:
        seed: Seed forwarded to ``tl.randint4x``.
        offset: Offset(s) forwarded to ``tl.randint4x``; the result has the
            same shape as the values ``tl.randint4x`` returns for it.
        includes_zero: Compile-time flag. When False, the result is clamped
            from below to the smallest positive normal float64 so callers
            may safely take ``log(u)``.

    Returns:
        float64 values in ``[0, 1)`` (``(0, 1)`` when ``includes_zero`` is
        False).
    """
    # Only two of the four 32-bit words are needed for a 64-bit integer.
    lo, hi, _, _ = tl.randint4x(seed, offset)
    # Reinterpret the words as unsigned before widening so that no sign
    # extension leaks into the upper half of the 64-bit value.
    lo = lo.to(tl.uint32, bitcast=True).to(tl.uint64)
    hi = hi.to(tl.uint32, bitcast=True).to(tl.uint64)
    r = (hi << 32) | lo

    # 1 / 2**64
    scale = 5.421010862427522170037e-20
    u = r.to(tl.float64) * scale

    # float64 has a 53-bit significand, so converting r rounds to nearest:
    # any r >= 2**64 - 1024 becomes exactly 2**64 and u becomes exactly 1.0.
    # A caller computing -log(-log(u)) would then get +inf. Remap that single
    # value to 1 - 2**-53, the largest float64 below 1.0; every other value
    # of u is left untouched.
    half_ulp = 1.1102230246251565e-16  # 2**-53
    u = tl.where(u >= 1.0, u - half_ulp, u)

    if not includes_zero:
        u = tl.maximum(u, 2.2250738585072014e-308)  # float64 tiny
    return u
@triton.jit
def _gumbel_sample_kernel(
local_argmax_ptr,
@@ -95,15 +110,16 @@ def _gumbel_sample_kernel(
mask=mask,
)
logits = logits.to(tl.float64)
if temp != 0.0:
# Calculate the seed for gumbel noise.
seed = tl.load(seeds_ptr + req_state_idx)
pos = tl.load(pos_ptr + token_idx)
gumbel_seed = tl.randint(seed, pos)
# Generate gumbel noise in FP32.
u = tl.rand(gumbel_seed, block)
u = tl.maximum(u, 1e-7)
# tl.rand returns fp32, so build a true fp64 uniform from 64 random
# bits before applying the double-log transform.
u = tl_rand64(gumbel_seed, block, includes_zero=False)
gumbel_noise = -tl.log(-tl.log(u))
# Apply gumbel noise.
@@ -128,7 +144,7 @@ def gumbel_sample(
BLOCK_SIZE = 1024
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
local_argmax = logits.new_empty(num_tokens, num_blocks, dtype=torch.int64)
local_max = logits.new_empty(num_tokens, num_blocks, dtype=torch.float32)
local_max = logits.new_empty(num_tokens, num_blocks, dtype=torch.float64)
_gumbel_sample_kernel[(num_tokens, num_blocks)](
local_argmax,
local_argmax.stride(0),

View File

@@ -6,7 +6,7 @@ from vllm.triton_utils import tl, triton
from vllm.v1.outputs import LogprobsTensors
from vllm.v1.worker.gpu.input_batch import InputBatch
from vllm.v1.worker.gpu.metrics.logits import get_num_nans
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample, tl_rand64
from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
from vllm.v1.worker.gpu.sample.output import SamplerOutput
from vllm.v1.worker.gpu.sample.sampler import Sampler
@@ -211,12 +211,12 @@ def _probabilistic_rejection_kernel(
else:
target_prob = tl.load(
target_probs_ptr + logit_idx * target_probs_stride + draft_sampled
)
).to(tl.float64)
draft_prob = tl.load(
draft_probs_ptr + logit_idx * draft_probs_stride + draft_sampled
)
).to(tl.float64)
pos = tl.load(pos_ptr + logit_idx)
u = tl.sum(tl.rand(seed, pos + tl.arange(0, 1)))
u = tl_rand64(seed, pos, includes_zero=False)
accepted &= target_prob > u * draft_prob
tl.store(sampled_ptr + req_idx * sampled_stride + i, draft_sampled)
rejected_step += accepted