[Model Runner V2] Use FP32 for Gumbel Noise (#34854)
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
This commit is contained in:
@@ -85,10 +85,10 @@ def _gumbel_sample_kernel(
 pos = tl.load(pos_ptr + batch_idx)
 gumbel_seed = tl.randint(seed, pos)
 
-# Generate gumbel noise.
-r = tl.rand(gumbel_seed, block).to(tl.float64)
-gumbel_noise = -tl.log(-tl.log(r + 1e-20) + 1e-20)
-gumbel_noise = gumbel_noise.to(tl.float32)
+# Generate gumbel noise in FP32.
+u = tl.rand(gumbel_seed, block)
+u = tl.maximum(u, 1e-7)
+gumbel_noise = -tl.log(-tl.log(u))
 
 # Apply temperature.
 if APPLY_TEMPERATURE:
@@ -99,18 +99,17 @@ def _gumbel_sample_kernel(
 # Apply gumbel noise.
 logits = tl.where(mask, logits + gumbel_noise, float("-inf"))
 
-idx = tl.argmax(logits, axis=0)
+value, idx = tl.max(logits, axis=0, return_indices=True)
 token_id = block_idx * BLOCK_SIZE + idx
-value = tl.max(logits, axis=0)
 tl.store(local_argmax_ptr + batch_idx * local_argmax_stride + block_idx, token_id)
 tl.store(local_max_ptr + batch_idx * local_max_stride + block_idx, value)
 
 
 def gumbel_sample(
 logits: torch.Tensor, # [num_reqs, vocab_size]
-idx_mapping: torch.Tensor, # [num_reqs]
-temperature: torch.Tensor, # [num_reqs]
-seed: torch.Tensor, # [num_reqs]
+idx_mapping: torch.Tensor, # [max_num_reqs]
+temperature: torch.Tensor, # [max_num_reqs]
+seed: torch.Tensor, # [max_num_reqs]
 pos: torch.Tensor, # [num_reqs]
 apply_temperature: bool,
 ) -> torch.Tensor:
|
||||
|
||||
Reference in New Issue
Block a user