[BugFix][Spec Decode] Use float64 for uniform_probs (#23803)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Woosuk Kwon
2025-08-28 05:26:45 -07:00
committed by GitHub
parent 67cee40da0
commit a3432f18fd
2 changed files with 7 additions and 2 deletions
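
The code comment added by this change argues that float32 sampling has a non-negligible chance of returning exactly 0.0. A rough sketch of that arithmetic follows. It assumes a simplified model in which the generator picks one of 2**p equally likely grid points in [0, 1), where p is the significand precision (24 bits for float32, 53 bits for float64). The helper names are hypothetical, and the actual PyTorch kernels may differ by device and version.

```python
import math

# Significand precision of IEEE 754 binary32 and binary64.
FLOAT32_BITS = 24
FLOAT64_BITS = 53


def prob_exact_zero(significand_bits: int) -> float:
    """Chance of drawing exactly 0.0 under the assumed grid model.

    Assumption: the sampler returns k / 2**bits with k uniform in
    [0, 2**bits), so 0.0 is one of 2**bits equally likely outcomes.
    """
    return math.ldexp(1.0, -significand_bits)


def expected_zeros(num_draws: int, significand_bits: int) -> float:
    """Expected number of exact zeros among num_draws independent samples."""
    return num_draws * prob_exact_zero(significand_bits)


p32 = prob_exact_zero(FLOAT32_BITS)  # about 5.96e-08
p64 = prob_exact_zero(FLOAT64_BITS)  # about 1.11e-16

print(f"float32: P(0.0) = {p32:.3e}")
print(f"float64: P(0.0) = {p64:.3e}")
print(f"expected zeros in 1e9 float32 draws: {expected_zeros(10**9, FLOAT32_BITS):.1f}")
```

Under this model, a workload that draws on the order of a billion values would expect to see tens of exact zeros in float32, while float64 lowers the per-draw chance by a factor of 2**29.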

@@ -365,9 +365,14 @@ def generate_uniform_probs(
             A tensor of shape `(num_tokens, )` containing uniform
             random values in the range [0, 1).
     """
+    # NOTE(woosuk): We deliberately use float64 instead of float32 here
+    # because when using float32, there's a non-negligible chance that
+    # uniform_prob is sampled to be exact 0.0 as reported in
+    # https://github.com/pytorch/pytorch/issues/16706. Using float64
+    # mitigates the issue.
     uniform_probs = torch.rand(
         (num_tokens, ),
-        dtype=torch.float32,
+        dtype=torch.float64,
         device=device,
     )
     start_idx = 0