[MRV2] Use FP64 for Gumbel noise (#37798)
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
This commit is contained in:
@@ -49,6 +49,21 @@ def apply_temperature(
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
def tl_rand64(seed, offset, includes_zero: tl.constexpr):
    """Draw float64 uniform samples from 64 random bits.

    Two of the random words produced by ``tl.randint4x`` are combined into
    a 64-bit unsigned integer, and its top 53 bits are scaled into the unit
    interval.

    Args:
        seed: Seed forwarded to ``tl.randint4x``.
        offset: Offset(s) forwarded to ``tl.randint4x``.
        includes_zero: Compile-time flag. When False, the result is clamped
            from below to the smallest positive normal float64 so callers
            can safely take ``log(u)``.

    Returns:
        float64 value(s) in ``[0, 1)``; strictly positive when
        ``includes_zero`` is False. Exactly 1.0 is never produced.
    """
    lo, hi, _, _ = tl.randint4x(seed, offset)
    # Reinterpret the words as unsigned before widening so that widening to
    # 64 bits cannot sign-extend into the upper half.
    lo = lo.to(tl.uint32, bitcast=True).to(tl.uint64)
    hi = hi.to(tl.uint32, bitcast=True).to(tl.uint64)
    r = (hi << 32) | lo

    # float64 has a 53-bit significand. Converting all 64 bits would round
    # values >= 2**64 - 2**10 up to 2**64, making u == 1.0 and turning
    # -log(-log(u)) into +inf. Keeping only the top 53 bits makes the
    # integer -> float conversion exact and bounds u by 1 - 2**-53.
    r = r >> 11

    # 1 / 2**53 (a power of two, so exactly representable)
    scale = 1.1102230246251565404e-16
    u = r.to(tl.float64) * scale
    if not includes_zero:
        u = tl.maximum(u, 2.2250738585072014e-308)  # float64 tiny
    return u
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _gumbel_sample_kernel(
|
||||
local_argmax_ptr,
|
||||
@@ -95,15 +110,16 @@ def _gumbel_sample_kernel(
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
logits = logits.to(tl.float64)
|
||||
if temp != 0.0:
|
||||
# Calculate the seed for gumbel noise.
|
||||
seed = tl.load(seeds_ptr + req_state_idx)
|
||||
pos = tl.load(pos_ptr + token_idx)
|
||||
gumbel_seed = tl.randint(seed, pos)
|
||||
|
||||
# Generate gumbel noise in FP32.
|
||||
u = tl.rand(gumbel_seed, block)
|
||||
u = tl.maximum(u, 1e-7)
|
||||
# tl.rand returns fp32, so build a true fp64 uniform from 64 random
|
||||
# bits before applying the double-log transform.
|
||||
u = tl_rand64(gumbel_seed, block, includes_zero=False)
|
||||
gumbel_noise = -tl.log(-tl.log(u))
|
||||
|
||||
# Apply gumbel noise.
|
||||
@@ -128,7 +144,7 @@ def gumbel_sample(
|
||||
BLOCK_SIZE = 1024
|
||||
num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
|
||||
local_argmax = logits.new_empty(num_tokens, num_blocks, dtype=torch.int64)
|
||||
local_max = logits.new_empty(num_tokens, num_blocks, dtype=torch.float32)
|
||||
local_max = logits.new_empty(num_tokens, num_blocks, dtype=torch.float64)
|
||||
_gumbel_sample_kernel[(num_tokens, num_blocks)](
|
||||
local_argmax,
|
||||
local_argmax.stride(0),
|
||||
|
||||
@@ -6,7 +6,7 @@ from vllm.triton_utils import tl, triton
|
||||
from vllm.v1.outputs import LogprobsTensors
|
||||
from vllm.v1.worker.gpu.input_batch import InputBatch
|
||||
from vllm.v1.worker.gpu.metrics.logits import get_num_nans
|
||||
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
|
||||
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample, tl_rand64
|
||||
from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
|
||||
from vllm.v1.worker.gpu.sample.output import SamplerOutput
|
||||
from vllm.v1.worker.gpu.sample.sampler import Sampler
|
||||
@@ -211,12 +211,12 @@ def _probabilistic_rejection_kernel(
|
||||
else:
|
||||
target_prob = tl.load(
|
||||
target_probs_ptr + logit_idx * target_probs_stride + draft_sampled
|
||||
)
|
||||
).to(tl.float64)
|
||||
draft_prob = tl.load(
|
||||
draft_probs_ptr + logit_idx * draft_probs_stride + draft_sampled
|
||||
)
|
||||
).to(tl.float64)
|
||||
pos = tl.load(pos_ptr + logit_idx)
|
||||
u = tl.sum(tl.rand(seed, pos + tl.arange(0, 1)))
|
||||
u = tl_rand64(seed, pos, includes_zero=False)
|
||||
accepted &= target_prob > u * draft_prob
|
||||
tl.store(sampled_ptr + req_idx * sampled_stride + i, draft_sampled)
|
||||
rejected_step += accepted
|
||||
|
||||
Reference in New Issue
Block a user