[Model Runner V2] Some code simplification (#36929)

Signed-off-by: Nick Hill <nickhill123@gmail.com>
Author: Nick Hill
Date: 2026-03-12 17:59:23 -07:00 (committed by GitHub)
Parent: aaa3092f51
Commit: cd32d6f586
3 changed files with 17 additions and 59 deletions
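
The recurring change in this commit replaces torch.empty(..., dtype=..., device=...) with Tensor.new_empty(...), which inherits dtype and device from the tensor it is called on unless explicitly overridden. A minimal standalone sketch of the equivalence (the tensor names and shapes here are illustrative, not the runner's):

import torch

logits = torch.randn(4, 32000)  # stands in for an existing tensor

# Before: dtype and device spelled out explicitly.
a = torch.empty(4, 8, dtype=torch.int64, device=logits.device)

# After: new_empty inherits the device (and the dtype, unless overridden).
b = logits.new_empty(4, 8, dtype=torch.int64)

assert a.shape == b.shape and a.dtype == b.dtype and a.device == b.device

This also drops the need for a separate device = ... local, as seen in probabilistic_rejection_sample below.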

View File

@@ -57,10 +57,7 @@ SpeculativeMethod = Literal[
     EagleModelTypes,
     NgramGPUTypes,
 ]
-RejectionSampleMethod = Literal[
-    "strict",
-    "probabilistic",
-]
+RejectionSampleMethod = Literal["strict", "probabilistic"]
 @config

View File

@@ -81,7 +81,7 @@ def _gumbel_sample_kernel(
     logits = logits.to(tl.float32)
     temp = tl.load(temp_ptr + req_state_idx).to(tl.float32)
-    if (temp != 0.0) and APPLY_TEMPERATURE:
+    if temp != 0.0 and APPLY_TEMPERATURE:
         # Apply temperature.
         # NOTE(woosuk): Match the behavior of _temperature_kernel.
         # E.g., if the kernel uses tl.div_rn, we should use tl.div_rn here too.
@@ -127,18 +127,8 @@ def gumbel_sample(
     num_tokens, vocab_size = logits.shape
     BLOCK_SIZE = 1024
     num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
-    local_argmax = torch.empty(
-        num_tokens,
-        num_blocks,
-        dtype=torch.int64,
-        device=logits.device,
-    )
-    local_max = torch.empty(
-        num_tokens,
-        num_blocks,
-        dtype=torch.float32,
-        device=logits.device,
-    )
+    local_argmax = logits.new_empty(num_tokens, num_blocks, dtype=torch.int64)
+    local_max = logits.new_empty(num_tokens, num_blocks, dtype=torch.float32)
     _gumbel_sample_kernel[(num_tokens, num_blocks)](
         local_argmax,
         local_argmax.stride(0),

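For context on the buffers simplified above: gumbel_sample launches its kernel on a (num_tokens, num_blocks) grid, so local_max and local_argmax hold one per-block max/argmax pair per token, to be reduced to a global argmax afterwards. A pure-PyTorch sketch of that two-stage reduction (the helper name and the final reduce step are assumptions for illustration, not the kernel's actual code):

import torch

def two_stage_argmax(logits: torch.Tensor, block_size: int = 1024) -> torch.Tensor:
    # Hypothetical mirror of the blockwise scheme set up in gumbel_sample.
    num_tokens, vocab_size = logits.shape
    num_blocks = -(-vocab_size // block_size)  # same as triton.cdiv
    pad = num_blocks * block_size - vocab_size
    padded = torch.nn.functional.pad(logits, (0, pad), value=float("-inf"))
    blocks = padded.view(num_tokens, num_blocks, block_size)
    # Stage 1: per-block max and argmax (what local_max/local_argmax store).
    local_max, local_argmax = blocks.max(dim=-1)
    # Stage 2: reduce across blocks to one global argmax per token.
    best_block = local_max.argmax(dim=-1)
    within_block = local_argmax.gather(1, best_block.unsqueeze(1)).squeeze(1)
    return best_block * block_size + within_block
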
View File

@@ -53,17 +53,8 @@ def strict_rejection_sample(
     num_speculative_steps,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     num_reqs = cu_num_logits.shape[0] - 1
-    sampled = torch.empty(
-        num_reqs,
-        num_speculative_steps + 1,
-        dtype=target_sampled.dtype,
-        device=target_sampled.device,
-    )
-    num_sampled = torch.empty(
-        num_reqs,
-        dtype=torch.int32,
-        device=target_sampled.device,
-    )
+    sampled = target_sampled.new_empty(num_reqs, num_speculative_steps + 1)
+    num_sampled = target_sampled.new_empty(num_reqs, dtype=torch.int32)
     _strict_rejection_sample_kernel[(num_reqs,)](
         sampled,
         sampled.stride(0),
@@ -216,12 +207,11 @@ def probabilistic_rejection_sample(
     pos: torch.Tensor,
     # [num_reqs]
     idx_mapping: torch.Tensor,
-    temperature,
-    seeds,
-    num_speculative_steps,
+    temperature: torch.Tensor,
+    seed: torch.Tensor,
+    num_speculative_steps: int,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     num_reqs = cu_num_logits.shape[0] - 1
-    device = target_logits.device
     vocab_size = target_logits.shape[-1]
     # Compute target and draft probs.
@@ -230,18 +220,11 @@ def probabilistic_rejection_sample(
     # Rejection sample.
     # [num_reqs, num_speculative_steps + 1]
-    sampled = torch.empty(
-        num_reqs,
-        num_speculative_steps + 1,
-        dtype=torch.int64,
-        device=device,
+    sampled = draft_sampled.new_empty(
+        num_reqs, num_speculative_steps + 1, dtype=torch.int64
     )
     # [num_reqs]
-    rejected_steps = torch.empty(
-        num_reqs,
-        dtype=torch.int64,
-        device=device,
-    )
+    rejected_steps = sampled.new_empty(num_reqs)
     _probabilistic_rejection_sample_kernel[(num_reqs,)](
         sampled,
         sampled.stride(0),
@@ -255,25 +238,16 @@ def probabilistic_rejection_sample(
         cu_num_logits,
         pos,
         idx_mapping,
-        seeds,
+        seed,
         num_warps=1,
     )
     # Compute the logits and positions to resample the rejected/bonus
     # tokens from.
     # [num_reqs, vocab_size]
-    residual_logits = torch.empty(
-        num_reqs,
-        vocab_size,
-        dtype=target_logits.dtype,
-        device=device,
-    )
+    residual_logits = target_logits.new_empty(num_reqs, vocab_size)
     # [num_reqs]
-    residual_pos = torch.empty(
-        num_reqs,
-        dtype=pos.dtype,
-        device=device,
-    )
+    residual_pos = pos.new_empty(num_reqs)
     BLOCK_SIZE = 1024
     num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
     _compute_residual_logits_kernel[(num_reqs, num_blocks)](
@@ -299,7 +273,7 @@ def probabilistic_rejection_sample(
         residual_logits,
         idx_mapping,
         temperature,
-        seeds,
+        seed,
         residual_pos,
         apply_temperature=False,
     )
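
The residual_logits prepared in the hunks above feed the recovery step of rejection-based speculative decoding: when a draft token is rejected, the replacement is drawn from the normalized positive difference between the target and draft distributions. A sketch of that textbook construction (the standard formulation, not necessarily the exact arithmetic of _compute_residual_logits_kernel):

import torch

def residual_probs(p_target: torch.Tensor, q_draft: torch.Tensor) -> torch.Tensor:
    # On rejection, resample from max(0, p - q), renormalized per row.
    r = (p_target - q_draft).clamp_min(0.0)
    return r / r.sum(dim=-1, keepdim=True)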
@@ -331,10 +305,7 @@ class RejectionSampler:
         num_nans = get_num_nans(logits) if self.sampler.compute_nans else None
         if self.use_strict_rejection_sampling:
-            sampler_output = self.sampler(
-                logits,
-                input_batch,
-            )
+            sampler_output = self.sampler(logits, input_batch)
             logprobs_tensors = sampler_output.logprobs_tensors
             sampled, num_sampled = strict_rejection_sample(
                 sampler_output.sampled_token_ids.view(-1),