[Model Runner V2] Skip kernel launch for penalties & logit_bias (#32634)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.sampling_params import SamplingParams
|
||||
@@ -49,12 +50,18 @@ class LogitBiasState:
|
||||
device=device,
|
||||
)
|
||||
|
||||
# Using any of the above.
|
||||
self.use_logit_bias = np.zeros(max_num_reqs, dtype=bool)
|
||||
|
||||
def add_request(
|
||||
self,
|
||||
req_idx: int,
|
||||
prompt_len: int,
|
||||
sampling_params: SamplingParams,
|
||||
) -> None:
|
||||
# Using any logit bias.
|
||||
use_logit_bias = False
|
||||
|
||||
# Allowed token IDs.
|
||||
allowed_token_ids = sampling_params.allowed_token_ids
|
||||
if allowed_token_ids:
|
||||
@@ -66,6 +73,7 @@ class LogitBiasState:
|
||||
)
|
||||
self.num_allowed_token_ids.np[req_idx] = num_allowed_token_ids
|
||||
self.allowed_token_ids.stage_write(req_idx, 0, allowed_token_ids)
|
||||
use_logit_bias = True
|
||||
else:
|
||||
self.num_allowed_token_ids.np[req_idx] = 0
|
||||
|
||||
@@ -81,6 +89,7 @@ class LogitBiasState:
|
||||
self.num_logit_bias.np[req_idx] = num_logit_bias
|
||||
self.logit_bias_token_ids.stage_write(req_idx, 0, logit_bias.keys())
|
||||
self.logit_bias.stage_write(req_idx, 0, logit_bias.values())
|
||||
use_logit_bias = True
|
||||
else:
|
||||
self.num_logit_bias.np[req_idx] = 0
|
||||
|
||||
@@ -89,7 +98,7 @@ class LogitBiasState:
|
||||
min_len = prompt_len + min_tokens
|
||||
self.min_lens.np[req_idx] = min_len
|
||||
stop_token_ids = sampling_params.all_stop_token_ids
|
||||
if stop_token_ids:
|
||||
if min_tokens > 0 and stop_token_ids:
|
||||
num_stop_token_ids = len(stop_token_ids)
|
||||
if num_stop_token_ids > MAX_NUM_STOP_TOKEN_IDS:
|
||||
raise ValueError(
|
||||
@@ -98,9 +107,12 @@ class LogitBiasState:
|
||||
)
|
||||
self.num_stop_token_ids.np[req_idx] = num_stop_token_ids
|
||||
self.stop_token_ids.stage_write(req_idx, 0, stop_token_ids)
|
||||
use_logit_bias = True
|
||||
else:
|
||||
self.num_stop_token_ids.np[req_idx] = 0
|
||||
|
||||
self.use_logit_bias[req_idx] = use_logit_bias
|
||||
|
||||
def apply_staged_writes(self) -> None:
|
||||
self.num_allowed_token_ids.copy_to_uva()
|
||||
self.allowed_token_ids.apply_write()
|
||||
@@ -117,8 +129,13 @@ class LogitBiasState:
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
idx_mapping: torch.Tensor,
|
||||
idx_mapping_np: np.ndarray,
|
||||
pos: torch.Tensor,
|
||||
) -> None:
|
||||
if not np.any(self.use_logit_bias[idx_mapping_np]):
|
||||
# No request uses logit bias. Skip the kernel launch.
|
||||
return
|
||||
|
||||
apply_logit_bias(
|
||||
logits,
|
||||
idx_mapping,
|
||||
|
||||
@@ -18,6 +18,7 @@ class PenaltiesState:
|
||||
self.repetition_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
|
||||
self.frequency_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
|
||||
self.presence_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
|
||||
self.use_penalty = np.zeros(max_num_reqs, dtype=bool)
|
||||
|
||||
# Initialize repetition penalty manually because 0 is an invalid value for it.
|
||||
self.repetition_penalty.np.fill(1.0)
|
||||
@@ -42,7 +43,10 @@ class PenaltiesState:
|
||||
self.repetition_penalty.np[req_idx] = sampling_params.repetition_penalty
|
||||
self.frequency_penalty.np[req_idx] = sampling_params.frequency_penalty
|
||||
self.presence_penalty.np[req_idx] = sampling_params.presence_penalty
|
||||
if use_penalty(sampling_params):
|
||||
|
||||
do_penalty = use_penalty(sampling_params)
|
||||
self.use_penalty[req_idx] = do_penalty
|
||||
if do_penalty:
|
||||
self._penalties_reqs.append(req_idx)
|
||||
|
||||
def apply_staged_writes(
|
||||
@@ -66,7 +70,16 @@ class PenaltiesState:
|
||||
self.frequency_penalty.copy_to_uva()
|
||||
self.presence_penalty.copy_to_uva()
|
||||
|
||||
def apply_penalties(self, logits: torch.Tensor, idx_mapping: torch.Tensor) -> None:
|
||||
def apply_penalties(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
idx_mapping: torch.Tensor,
|
||||
idx_mapping_np: np.ndarray,
|
||||
) -> None:
|
||||
if not np.any(self.use_penalty[idx_mapping_np]):
|
||||
# No request uses penalties. Skip the kernel launch.
|
||||
return
|
||||
|
||||
apply_penalties(
|
||||
logits,
|
||||
idx_mapping,
|
||||
|
||||
@@ -104,10 +104,10 @@ class Sampler:
|
||||
logits = torch.empty_like(logits, dtype=torch.float32).copy_(logits)
|
||||
|
||||
# Apply logit bias (e.g., allowed_token_ids, min_tokens) in place.
|
||||
self.logit_bias_state.apply_logit_bias(logits, idx_mapping, pos)
|
||||
self.logit_bias_state.apply_logit_bias(logits, idx_mapping, idx_mapping_np, pos)
|
||||
|
||||
# Apply penalties in place.
|
||||
self.penalties_state.apply_penalties(logits, idx_mapping)
|
||||
self.penalties_state.apply_penalties(logits, idx_mapping, idx_mapping_np)
|
||||
|
||||
# Apply temperature in place.
|
||||
apply_temperature(logits, idx_mapping, self.sampling_states.temperature.gpu)
|
||||
|
||||
Reference in New Issue
Block a user