[Model Runner V2] Minor cleanup for Sampler (#34563)

Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
This commit is contained in:
Woosuk Kwon
2026-02-14 18:29:03 -08:00
committed by GitHub
parent d5fe3f702c
commit 9ca768c740
2 changed files with 47 additions and 21 deletions

View File

@@ -7,12 +7,10 @@ import torch
import vllm.envs as envs
from vllm.config.model import LogprobsMode
from vllm.sampling_params import SamplingParams
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.worker.gpu.metrics.logits import get_num_nans
from vllm.v1.worker.gpu.sample.gumbel import apply_temperature, gumbel_sample
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
from vllm.v1.worker.gpu.sample.logit_bias import LogitBiasState
from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
from vllm.v1.worker.gpu.sample.min_p import apply_min_p
from vllm.v1.worker.gpu.sample.output import SamplerOutput
from vllm.v1.worker.gpu.sample.penalties import PenaltiesState
from vllm.v1.worker.gpu.sample.states import NO_LOGPROBS, SamplingStates
@@ -127,20 +125,15 @@ class Sampler:
)
# Apply temperature in place.
apply_temperature(logits, idx_mapping, self.sampling_states.temperature.gpu)
self.sampling_states.apply_temperature(logits, idx_mapping, idx_mapping_np)
# Apply min_p in place if any request has a non-zero min_p.
do_min_p = self.sampling_states.do_min_p(idx_mapping_np)
if do_min_p:
apply_min_p(logits, idx_mapping, self.sampling_states.min_p.gpu)
# Apply min_p in place.
self.sampling_states.apply_min_p(logits, idx_mapping, idx_mapping_np)
# Apply top_k and/or top_p. This might return a new tensor.
do_top_k = self.sampling_states.do_top_k(idx_mapping_np)
top_k = self.sampling_states.top_k.gpu[idx_mapping] if do_top_k else None
do_top_p = self.sampling_states.do_top_p(idx_mapping_np)
top_p = self.sampling_states.top_p.gpu[idx_mapping] if do_top_p else None
if do_top_k or do_top_p:
logits = apply_top_k_top_p(logits, top_k, top_p)
# Apply top_k and/or top_p. This might or might not return a new tensor.
logits = self.sampling_states.apply_top_k_top_p(
logits, idx_mapping, idx_mapping_np
)
# Sample the next token.
sampled = gumbel_sample(

View File

@@ -4,7 +4,10 @@ import numpy as np
import torch
from vllm.sampling_params import SamplingParams
from vllm.v1.sample.ops.topk_topp_sampler import apply_top_k_top_p
from vllm.v1.worker.gpu.buffer_utils import UvaBackedTensor
from vllm.v1.worker.gpu.sample.gumbel import apply_temperature
from vllm.v1.worker.gpu.sample.min_p import apply_min_p
NO_LOGPROBS = -1
_NP_INT64_MIN = np.iinfo(np.int64).min
@@ -58,14 +61,44 @@ class SamplingStates:
self.min_p.copy_to_uva()
self.seeds.copy_to_uva()
def do_min_p(self, idx_mapping_np: np.ndarray) -> bool:
return np.any(self.min_p.np[idx_mapping_np] != 0.0)
def apply_temperature(
self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
) -> None:
temp_np = self.temperature.np[idx_mapping_np]
if np.all((temp_np == 0.0) | (temp_np == 1.0)):
# No request requires temperature. Skip the kernel launch.
return
def do_top_k(self, idx_mapping_np: np.ndarray) -> bool:
return np.any(self.top_k.np[idx_mapping_np] != self.vocab_size)
apply_temperature(logits, idx_mapping, self.temperature.gpu)
def do_top_p(self, idx_mapping_np: np.ndarray) -> bool:
return np.any(self.top_p.np[idx_mapping_np] != 1.0)
def apply_min_p(
self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
) -> None:
if np.all(self.min_p.np[idx_mapping_np] == 0.0):
# No request uses min_p. Skip the kernel launch.
return
apply_min_p(logits, idx_mapping, self.min_p.gpu)
def apply_top_k_top_p(
self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
) -> torch.Tensor:
do_top_k = np.any(self.top_k.np[idx_mapping_np] != self.vocab_size)
do_top_p = np.any(self.top_p.np[idx_mapping_np] != 1.0)
if not (do_top_k or do_top_p):
return logits
top_k = self.top_k.gpu[idx_mapping] if do_top_k else None
top_p = self.top_p.gpu[idx_mapping] if do_top_p else None
return apply_top_k_top_p(logits, top_k, top_p)
def max_num_logprobs(self, idx_mapping_np: np.ndarray) -> int:
return int(np.max(self.num_logprobs[idx_mapping_np]))