From a5e9d511defe2d2dc2dd270674fc197542fc0169 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 22 Mar 2026 12:28:10 -0700 Subject: [PATCH] [MRV2] Use FP64 for Gumbel noise (#37798) Signed-off-by: Woosuk Kwon --- vllm/v1/worker/gpu/sample/gumbel.py | 24 +++++++++++++++---- .../gpu/spec_decode/rejection_sampler.py | 8 +++---- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/vllm/v1/worker/gpu/sample/gumbel.py b/vllm/v1/worker/gpu/sample/gumbel.py index ed7a1dde6..0d08ceb83 100644 --- a/vllm/v1/worker/gpu/sample/gumbel.py +++ b/vllm/v1/worker/gpu/sample/gumbel.py @@ -49,6 +49,21 @@ def apply_temperature( ) +@triton.jit +def tl_rand64(seed, offset, includes_zero: tl.constexpr): + lo, hi, _, _ = tl.randint4x(seed, offset) + lo = lo.to(tl.uint32, bitcast=True).to(tl.uint64) + hi = hi.to(tl.uint32, bitcast=True).to(tl.uint64) + r = (hi << 32) | lo + + # Top 53 bits convert to fp64 exactly, so u < 1.0 always (scale = 2**-53). + scale = 1.1102230246251565e-16 + u = (r >> 11).to(tl.float64) * scale + if not includes_zero: + u = tl.maximum(u, 2.2250738585072014e-308) # float64 tiny + return u + + @triton.jit def _gumbel_sample_kernel( local_argmax_ptr, @@ -95,15 +110,16 @@ def _gumbel_sample_kernel( mask=mask, ) + logits = logits.to(tl.float64) if temp != 0.0: # Calculate the seed for gumbel noise. seed = tl.load(seeds_ptr + req_state_idx) pos = tl.load(pos_ptr + token_idx) gumbel_seed = tl.randint(seed, pos) - # Generate gumbel noise in FP32. - u = tl.rand(gumbel_seed, block) - u = tl.maximum(u, 1e-7) + # tl.rand returns fp32, so build a true fp64 uniform from 64 random + # bits before applying the double-log transform. + u = tl_rand64(gumbel_seed, block, includes_zero=False) gumbel_noise = -tl.log(-tl.log(u)) # Apply gumbel noise.
@@ -128,7 +144,7 @@ def gumbel_sample( BLOCK_SIZE = 1024 num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE) local_argmax = logits.new_empty(num_tokens, num_blocks, dtype=torch.int64) - local_max = logits.new_empty(num_tokens, num_blocks, dtype=torch.float32) + local_max = logits.new_empty(num_tokens, num_blocks, dtype=torch.float64) _gumbel_sample_kernel[(num_tokens, num_blocks)]( local_argmax, local_argmax.stride(0), diff --git a/vllm/v1/worker/gpu/spec_decode/rejection_sampler.py b/vllm/v1/worker/gpu/spec_decode/rejection_sampler.py index e1f483919..0c6e26aaa 100644 --- a/vllm/v1/worker/gpu/spec_decode/rejection_sampler.py +++ b/vllm/v1/worker/gpu/spec_decode/rejection_sampler.py @@ -6,7 +6,7 @@ from vllm.triton_utils import tl, triton from vllm.v1.outputs import LogprobsTensors from vllm.v1.worker.gpu.input_batch import InputBatch from vllm.v1.worker.gpu.metrics.logits import get_num_nans -from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample +from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample, tl_rand64 from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs from vllm.v1.worker.gpu.sample.output import SamplerOutput from vllm.v1.worker.gpu.sample.sampler import Sampler @@ -211,12 +211,12 @@ def _probabilistic_rejection_kernel( else: target_prob = tl.load( target_probs_ptr + logit_idx * target_probs_stride + draft_sampled - ) + ).to(tl.float64) draft_prob = tl.load( draft_probs_ptr + logit_idx * draft_probs_stride + draft_sampled - ) + ).to(tl.float64) pos = tl.load(pos_ptr + logit_idx) - u = tl.sum(tl.rand(seed, pos + tl.arange(0, 1))) + u = tl_rand64(seed, pos, includes_zero=False) accepted &= target_prob > u * draft_prob tl.store(sampled_ptr + req_idx * sampled_stride + i, draft_sampled) rejected_step += accepted