[Model Runner V2] Skip kernel launch for penalties & logit_bias (#32634)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon
2026-01-19 22:20:19 -08:00
committed by GitHub
parent b75e85dede
commit e9c83cdc51
3 changed files with 35 additions and 5 deletions

View File

@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch
from vllm.sampling_params import SamplingParams
@@ -49,12 +50,18 @@ class LogitBiasState:
device=device,
)
# Using any of the above.
self.use_logit_bias = np.zeros(max_num_reqs, dtype=bool)
def add_request(
self,
req_idx: int,
prompt_len: int,
sampling_params: SamplingParams,
) -> None:
# Using any logit bias.
use_logit_bias = False
# Allowed token IDs.
allowed_token_ids = sampling_params.allowed_token_ids
if allowed_token_ids:
@@ -66,6 +73,7 @@ class LogitBiasState:
)
self.num_allowed_token_ids.np[req_idx] = num_allowed_token_ids
self.allowed_token_ids.stage_write(req_idx, 0, allowed_token_ids)
use_logit_bias = True
else:
self.num_allowed_token_ids.np[req_idx] = 0
@@ -81,6 +89,7 @@ class LogitBiasState:
self.num_logit_bias.np[req_idx] = num_logit_bias
self.logit_bias_token_ids.stage_write(req_idx, 0, logit_bias.keys())
self.logit_bias.stage_write(req_idx, 0, logit_bias.values())
use_logit_bias = True
else:
self.num_logit_bias.np[req_idx] = 0
@@ -89,7 +98,7 @@ class LogitBiasState:
min_len = prompt_len + min_tokens
self.min_lens.np[req_idx] = min_len
stop_token_ids = sampling_params.all_stop_token_ids
if stop_token_ids:
if min_tokens > 0 and stop_token_ids:
num_stop_token_ids = len(stop_token_ids)
if num_stop_token_ids > MAX_NUM_STOP_TOKEN_IDS:
raise ValueError(
@@ -98,9 +107,12 @@ class LogitBiasState:
)
self.num_stop_token_ids.np[req_idx] = num_stop_token_ids
self.stop_token_ids.stage_write(req_idx, 0, stop_token_ids)
use_logit_bias = True
else:
self.num_stop_token_ids.np[req_idx] = 0
self.use_logit_bias[req_idx] = use_logit_bias
def apply_staged_writes(self) -> None:
self.num_allowed_token_ids.copy_to_uva()
self.allowed_token_ids.apply_write()
@@ -117,8 +129,13 @@ class LogitBiasState:
self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
pos: torch.Tensor,
) -> None:
if not np.any(self.use_logit_bias[idx_mapping_np]):
# No request uses logit bias. Skip the kernel launch.
return
apply_logit_bias(
logits,
idx_mapping,

View File

@@ -18,6 +18,7 @@ class PenaltiesState:
self.repetition_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
self.frequency_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
self.presence_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
self.use_penalty = np.zeros(max_num_reqs, dtype=bool)
# Initialize repetition penalty manually because 0 is an invalid value for it.
self.repetition_penalty.np.fill(1.0)
@@ -42,7 +43,10 @@ class PenaltiesState:
self.repetition_penalty.np[req_idx] = sampling_params.repetition_penalty
self.frequency_penalty.np[req_idx] = sampling_params.frequency_penalty
self.presence_penalty.np[req_idx] = sampling_params.presence_penalty
if use_penalty(sampling_params):
do_penalty = use_penalty(sampling_params)
self.use_penalty[req_idx] = do_penalty
if do_penalty:
self._penalties_reqs.append(req_idx)
def apply_staged_writes(
@@ -66,7 +70,16 @@ class PenaltiesState:
self.frequency_penalty.copy_to_uva()
self.presence_penalty.copy_to_uva()
def apply_penalties(self, logits: torch.Tensor, idx_mapping: torch.Tensor) -> None:
def apply_penalties(
self,
logits: torch.Tensor,
idx_mapping: torch.Tensor,
idx_mapping_np: np.ndarray,
) -> None:
if not np.any(self.use_penalty[idx_mapping_np]):
# No request uses penalties. Skip the kernel launch.
return
apply_penalties(
logits,
idx_mapping,

View File

@@ -104,10 +104,10 @@ class Sampler:
logits = torch.empty_like(logits, dtype=torch.float32).copy_(logits)
# Apply logit bias (e.g., allowed_token_ids, min_tokens) in place.
self.logit_bias_state.apply_logit_bias(logits, idx_mapping, pos)
self.logit_bias_state.apply_logit_bias(logits, idx_mapping, idx_mapping_np, pos)
# Apply penalties in place.
self.penalties_state.apply_penalties(logits, idx_mapping)
self.penalties_state.apply_penalties(logits, idx_mapping, idx_mapping_np)
# Apply temperature in place.
apply_temperature(logits, idx_mapping, self.sampling_states.temperature.gpu)