[Model Runner V2] Some code simplification (#36929)

Signed-off-by: Nick Hill <nickhill123@gmail.com>
This commit is contained in:
Nick Hill
2026-03-12 17:59:23 -07:00
committed by GitHub
parent aaa3092f51
commit cd32d6f586
3 changed files with 17 additions and 59 deletions

View File

@@ -57,10 +57,7 @@ SpeculativeMethod = Literal[
     EagleModelTypes,
     NgramGPUTypes,
 ]
-RejectionSampleMethod = Literal[
-    "strict",
-    "probabilistic",
-]
+RejectionSampleMethod = Literal["strict", "probabilistic"]
 @config

View File

@@ -81,7 +81,7 @@ def _gumbel_sample_kernel(
     logits = logits.to(tl.float32)
     temp = tl.load(temp_ptr + req_state_idx).to(tl.float32)
-    if (temp != 0.0) and APPLY_TEMPERATURE:
+    if temp != 0.0 and APPLY_TEMPERATURE:
         # Apply temperature.
         # NOTE(woosuk): Match the behavior of _temperature_kernel.
         # E.g., if the kernel uses tl.div_rn, we should use tl.div_rn here too.
@@ -127,18 +127,8 @@ def gumbel_sample(
     num_tokens, vocab_size = logits.shape
     BLOCK_SIZE = 1024
     num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
-    local_argmax = torch.empty(
-        num_tokens,
-        num_blocks,
-        dtype=torch.int64,
-        device=logits.device,
-    )
-    local_max = torch.empty(
-        num_tokens,
-        num_blocks,
-        dtype=torch.float32,
-        device=logits.device,
-    )
+    local_argmax = logits.new_empty(num_tokens, num_blocks, dtype=torch.int64)
+    local_max = logits.new_empty(num_tokens, num_blocks, dtype=torch.float32)
     _gumbel_sample_kernel[(num_tokens, num_blocks)](
         local_argmax,
         local_argmax.stride(0),

View File

@@ -53,17 +53,8 @@ def strict_rejection_sample(
     num_speculative_steps,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     num_reqs = cu_num_logits.shape[0] - 1
-    sampled = torch.empty(
-        num_reqs,
-        num_speculative_steps + 1,
-        dtype=target_sampled.dtype,
-        device=target_sampled.device,
-    )
-    num_sampled = torch.empty(
-        num_reqs,
-        dtype=torch.int32,
-        device=target_sampled.device,
-    )
+    sampled = target_sampled.new_empty(num_reqs, num_speculative_steps + 1)
+    num_sampled = target_sampled.new_empty(num_reqs, dtype=torch.int32)
     _strict_rejection_sample_kernel[(num_reqs,)](
         sampled,
         sampled.stride(0),
@@ -216,12 +207,11 @@ def probabilistic_rejection_sample(
     pos: torch.Tensor,
     # [num_reqs]
     idx_mapping: torch.Tensor,
-    temperature,
-    seeds,
-    num_speculative_steps,
+    temperature: torch.Tensor,
+    seed: torch.Tensor,
+    num_speculative_steps: int,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     num_reqs = cu_num_logits.shape[0] - 1
-    device = target_logits.device
     vocab_size = target_logits.shape[-1]
     # Compute target and draft probs.
@@ -230,18 +220,11 @@ def probabilistic_rejection_sample(
     # Rejection sample.
     # [num_reqs, num_speculative_steps + 1]
-    sampled = torch.empty(
-        num_reqs,
-        num_speculative_steps + 1,
-        dtype=torch.int64,
-        device=device,
-    )
+    sampled = draft_sampled.new_empty(
+        num_reqs, num_speculative_steps + 1, dtype=torch.int64
+    )
     # [num_reqs]
-    rejected_steps = torch.empty(
-        num_reqs,
-        dtype=torch.int64,
-        device=device,
-    )
+    rejected_steps = sampled.new_empty(num_reqs)
     _probabilistic_rejection_sample_kernel[(num_reqs,)](
         sampled,
         sampled.stride(0),
@@ -255,25 +238,16 @@ def probabilistic_rejection_sample(
         cu_num_logits,
         pos,
         idx_mapping,
-        seeds,
+        seed,
         num_warps=1,
     )
     # Compute the logits and positions to resample the rejected/bonus
     # tokens from.
     # [num_reqs, vocab_size]
-    residual_logits = torch.empty(
-        num_reqs,
-        vocab_size,
-        dtype=target_logits.dtype,
-        device=device,
-    )
+    residual_logits = target_logits.new_empty(num_reqs, vocab_size)
     # [num_reqs]
-    residual_pos = torch.empty(
-        num_reqs,
-        dtype=pos.dtype,
-        device=device,
-    )
+    residual_pos = pos.new_empty(num_reqs)
     BLOCK_SIZE = 1024
     num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
     _compute_residual_logits_kernel[(num_reqs, num_blocks)](
@@ -299,7 +273,7 @@ def probabilistic_rejection_sample(
         residual_logits,
         idx_mapping,
         temperature,
-        seeds,
+        seed,
         residual_pos,
         apply_temperature=False,
     )
@@ -331,10 +305,7 @@ class RejectionSampler:
         num_nans = get_num_nans(logits) if self.sampler.compute_nans else None
         if self.use_strict_rejection_sampling:
-            sampler_output = self.sampler(
-                logits,
-                input_batch,
-            )
+            sampler_output = self.sampler(logits, input_batch)
             logprobs_tensors = sampler_output.logprobs_tensors
             sampled, num_sampled = strict_rejection_sample(
                 sampler_output.sampled_token_ids.view(-1),