From 5fcb0cdd68ded0d1c988121d4f744c2424522259 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Wed, 18 Feb 2026 17:07:37 -0800 Subject: [PATCH] [Model Runner V2] Use FP32 for Gumbel Noise (#34854) Signed-off-by: Woosuk Kwon --- vllm/v1/worker/gpu/sample/gumbel.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/vllm/v1/worker/gpu/sample/gumbel.py b/vllm/v1/worker/gpu/sample/gumbel.py index 3a0a6b6a0..84ff3a291 100644 --- a/vllm/v1/worker/gpu/sample/gumbel.py +++ b/vllm/v1/worker/gpu/sample/gumbel.py @@ -85,10 +85,10 @@ def _gumbel_sample_kernel( pos = tl.load(pos_ptr + batch_idx) gumbel_seed = tl.randint(seed, pos) - # Generate gumbel noise. - r = tl.rand(gumbel_seed, block).to(tl.float64) - gumbel_noise = -tl.log(-tl.log(r + 1e-20) + 1e-20) - gumbel_noise = gumbel_noise.to(tl.float32) + # Generate gumbel noise in FP32. + u = tl.rand(gumbel_seed, block) + u = tl.maximum(u, 1e-7) + gumbel_noise = -tl.log(-tl.log(u)) # Apply temperature. if APPLY_TEMPERATURE: @@ -99,18 +99,17 @@ def _gumbel_sample_kernel( # Apply gumbel noise. logits = tl.where(mask, logits + gumbel_noise, float("-inf")) - idx = tl.argmax(logits, axis=0) + value, idx = tl.max(logits, axis=0, return_indices=True) token_id = block_idx * BLOCK_SIZE + idx - value = tl.max(logits, axis=0) tl.store(local_argmax_ptr + batch_idx * local_argmax_stride + block_idx, token_id) tl.store(local_max_ptr + batch_idx * local_max_stride + block_idx, value) def gumbel_sample( logits: torch.Tensor, # [num_reqs, vocab_size] - idx_mapping: torch.Tensor, # [num_reqs] - temperature: torch.Tensor, # [num_reqs] - seed: torch.Tensor, # [num_reqs] + idx_mapping: torch.Tensor, # [max_num_reqs] + temperature: torch.Tensor, # [max_num_reqs] + seed: torch.Tensor, # [max_num_reqs] pos: torch.Tensor, # [num_reqs] apply_temperature: bool, ) -> torch.Tensor: