From cd32d6f5868a040430a88c4423e3307116feb433 Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Thu, 12 Mar 2026 17:59:23 -0700
Subject: [PATCH] [Model Runner V2] Some code simplification (#36929)

Signed-off-by: Nick Hill
---
 vllm/config/speculative.py                    |  5 +-
 vllm/v1/worker/gpu/sample/gumbel.py           | 16 +-----
 .../gpu/spec_decode/rejection_sampler.py      | 55 +++++--------------
 3 files changed, 17 insertions(+), 59 deletions(-)

diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py
index 360f1c32f..ceb82cf90 100644
--- a/vllm/config/speculative.py
+++ b/vllm/config/speculative.py
@@ -57,10 +57,7 @@ SpeculativeMethod = Literal[
     EagleModelTypes,
     NgramGPUTypes,
 ]
-RejectionSampleMethod = Literal[
-    "strict",
-    "probabilistic",
-]
+RejectionSampleMethod = Literal["strict", "probabilistic"]
 
 
 @config
diff --git a/vllm/v1/worker/gpu/sample/gumbel.py b/vllm/v1/worker/gpu/sample/gumbel.py
index 1f10d7bb2..ed7a1dde6 100644
--- a/vllm/v1/worker/gpu/sample/gumbel.py
+++ b/vllm/v1/worker/gpu/sample/gumbel.py
@@ -81,7 +81,7 @@ def _gumbel_sample_kernel(
     logits = logits.to(tl.float32)
 
     temp = tl.load(temp_ptr + req_state_idx).to(tl.float32)
-    if (temp != 0.0) and APPLY_TEMPERATURE:
+    if temp != 0.0 and APPLY_TEMPERATURE:
         # Apply temperature.
         # NOTE(woosuk): Match the behavior of _temperature_kernel.
         # E.g., if the kernel uses tl.div_rn, we should use tl.div_rn here too.
@@ -127,18 +127,8 @@ def gumbel_sample(
     num_tokens, vocab_size = logits.shape
     BLOCK_SIZE = 1024
     num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
-    local_argmax = torch.empty(
-        num_tokens,
-        num_blocks,
-        dtype=torch.int64,
-        device=logits.device,
-    )
-    local_max = torch.empty(
-        num_tokens,
-        num_blocks,
-        dtype=torch.float32,
-        device=logits.device,
-    )
+    local_argmax = logits.new_empty(num_tokens, num_blocks, dtype=torch.int64)
+    local_max = logits.new_empty(num_tokens, num_blocks, dtype=torch.float32)
     _gumbel_sample_kernel[(num_tokens, num_blocks)](
         local_argmax,
         local_argmax.stride(0),
diff --git a/vllm/v1/worker/gpu/spec_decode/rejection_sampler.py b/vllm/v1/worker/gpu/spec_decode/rejection_sampler.py
index bd640dab6..c835d86b2 100644
--- a/vllm/v1/worker/gpu/spec_decode/rejection_sampler.py
+++ b/vllm/v1/worker/gpu/spec_decode/rejection_sampler.py
@@ -53,17 +53,8 @@ def strict_rejection_sample(
     num_speculative_steps,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     num_reqs = cu_num_logits.shape[0] - 1
-    sampled = torch.empty(
-        num_reqs,
-        num_speculative_steps + 1,
-        dtype=target_sampled.dtype,
-        device=target_sampled.device,
-    )
-    num_sampled = torch.empty(
-        num_reqs,
-        dtype=torch.int32,
-        device=target_sampled.device,
-    )
+    sampled = target_sampled.new_empty(num_reqs, num_speculative_steps + 1)
+    num_sampled = target_sampled.new_empty(num_reqs, dtype=torch.int32)
     _strict_rejection_sample_kernel[(num_reqs,)](
         sampled,
         sampled.stride(0),
@@ -216,12 +207,11 @@ def probabilistic_rejection_sample(
     pos: torch.Tensor,
     # [num_reqs]
     idx_mapping: torch.Tensor,
-    temperature,
-    seeds,
-    num_speculative_steps,
+    temperature: torch.Tensor,
+    seed: torch.Tensor,
+    num_speculative_steps: int,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     num_reqs = cu_num_logits.shape[0] - 1
-    device = target_logits.device
     vocab_size = target_logits.shape[-1]
 
     # Compute target and draft probs.
@@ -230,18 +220,11 @@
 
     # Rejection sample.
     # [num_reqs, num_speculative_steps + 1]
-    sampled = torch.empty(
-        num_reqs,
-        num_speculative_steps + 1,
-        dtype=torch.int64,
-        device=device,
+    sampled = draft_sampled.new_empty(
+        num_reqs, num_speculative_steps + 1, dtype=torch.int64
     )
     # [num_reqs]
-    rejected_steps = torch.empty(
-        num_reqs,
-        dtype=torch.int64,
-        device=device,
-    )
+    rejected_steps = sampled.new_empty(num_reqs)
     _probabilistic_rejection_sample_kernel[(num_reqs,)](
         sampled,
         sampled.stride(0),
@@ -255,25 +238,16 @@
         cu_num_logits,
         pos,
         idx_mapping,
-        seeds,
+        seed,
         num_warps=1,
     )
 
     # Compute the logits and positions to resample the rejected/bonus
     # tokens from.
     # [num_reqs, vocab_size]
-    residual_logits = torch.empty(
-        num_reqs,
-        vocab_size,
-        dtype=target_logits.dtype,
-        device=device,
-    )
+    residual_logits = target_logits.new_empty(num_reqs, vocab_size)
     # [num_reqs]
-    residual_pos = torch.empty(
-        num_reqs,
-        dtype=pos.dtype,
-        device=device,
-    )
+    residual_pos = pos.new_empty(num_reqs)
     BLOCK_SIZE = 1024
     num_blocks = triton.cdiv(vocab_size, BLOCK_SIZE)
     _compute_residual_logits_kernel[(num_reqs, num_blocks)](
@@ -299,7 +273,7 @@
         residual_logits,
         idx_mapping,
         temperature,
-        seeds,
+        seed,
         residual_pos,
         apply_temperature=False,
     )
@@ -331,10 +305,7 @@ class RejectionSampler:
         num_nans = get_num_nans(logits) if self.sampler.compute_nans else None
 
        if self.use_strict_rejection_sampling:
-            sampler_output = self.sampler(
-                logits,
-                input_batch,
-            )
+            sampler_output = self.sampler(logits, input_batch)
            logprobs_tensors = sampler_output.logprobs_tensors
            sampled, num_sampled = strict_rejection_sample(
                sampler_output.sampled_token_ids.view(-1),
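
The one pattern this patch applies throughout is replacing explicit torch.empty(..., dtype=..., device=...) allocations with Tensor.new_empty(), which returns an uninitialized tensor that inherits the dtype and device of the tensor it is called on unless those are overridden. A minimal standalone sketch of that PyTorch idiom (illustrative only, not part of the patch; the tensor names below are made up):

import torch

# Source tensor whose dtype/device the new allocations should follow.
target_sampled = torch.zeros(8, dtype=torch.int64)

# Equivalent to torch.empty(4, 3, dtype=target_sampled.dtype,
# device=target_sampled.device); the memory is uninitialized, just like
# with torch.empty.
sampled = target_sampled.new_empty(4, 3)
assert sampled.dtype == torch.int64
assert sampled.device == target_sampled.device

# dtype (or device) can still be overridden per call, as the patch does
# for the int32 and int64 buffers.
num_sampled = target_sampled.new_empty(4, dtype=torch.int32)
assert num_sampled.dtype == torch.int32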