[Model Runner V2] Minor refactoring for penalties (#34662)

Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
This commit is contained in:
Woosuk Kwon
2026-02-16 21:43:00 -08:00
committed by GitHub
parent 9752da9d9c
commit d00df624f3
4 changed files with 94 additions and 75 deletions

View File

@@ -155,9 +155,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
max_num_reqs=self.max_num_reqs,
vocab_size=self.vocab_size,
device=self.device,
all_token_ids=self.req_states.all_token_ids.gpu,
prompt_len=self.req_states.prompt_len.gpu,
total_len=self.req_states.total_len.gpu,
req_states=self.req_states,
logprobs_mode=self.model_config.logprobs_mode,
num_speculative_tokens=self.num_speculative_steps + 1,
)
@@ -528,11 +526,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if scheduler_output.scheduled_new_reqs:
self.req_states.apply_staged_writes()
self.sampler.apply_staged_writes(
self.req_states.all_token_ids.gpu,
self.req_states.prefill_len.np,
self.req_states.prompt_len.np,
)
self.sampler.apply_staged_writes()
if self.uses_mrope:
self.mrope_states.apply_staged_writes()

View File

@@ -6,24 +6,17 @@ import torch
from vllm.sampling_params import SamplingParams
from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor
from vllm.v1.worker.gpu.states import RequestState
MAX_BAD_WORDS_TOTAL_TOKENS = 1024 # Max total tokens for all bad words per request
MAX_NUM_BAD_WORDS = 128 # Max number of bad words per request
class BadWordsState:
def __init__(
self,
all_token_ids: torch.Tensor,
prompt_len: torch.Tensor,
total_len: torch.Tensor,
):
self.all_token_ids = all_token_ids
self.prompt_len = prompt_len
self.total_len = total_len
self.max_num_reqs = prompt_len.shape[0]
self.device = prompt_len.device
def __init__(self, req_states: RequestState):
self.req_states = req_states
self.max_num_reqs = req_states.max_num_reqs
self.device = req_states.device
# flattened bad word tokens: [max_num_reqs, MAX_BAD_WORDS_TOTAL_TOKENS]
self.bad_word_token_ids = StagedWriteTensor(
@@ -95,9 +88,9 @@ class BadWordsState:
self.bad_word_token_ids.gpu,
self.bad_word_offsets.gpu,
self.num_bad_words.gpu,
self.all_token_ids,
self.prompt_len,
self.total_len,
self.req_states.all_token_ids.gpu,
self.req_states.prompt_len.gpu,
self.req_states.total_len.gpu,
input_ids,
expanded_local_pos,
max_num_bad_words,

View File

@@ -6,14 +6,18 @@ import torch
from vllm.sampling_params import SamplingParams
from vllm.triton_utils import tl, triton
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import async_tensor_h2d
from vllm.v1.worker.gpu.buffer_utils import UvaBackedTensor
from vllm.v1.worker.gpu.states import RequestState
class PenaltiesState:
def __init__(self, max_num_reqs: int, vocab_size: int, device: torch.device):
self.max_num_reqs = max_num_reqs
self.vocab_size = vocab_size
self.device = device
def __init__(self, req_states: RequestState):
self.req_states = req_states
max_num_reqs = req_states.max_num_reqs
self.vocab_size = req_states.vocab_size
self.device = req_states.device
self.repetition_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
self.frequency_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32)
@@ -26,7 +30,7 @@ class PenaltiesState:
# Statistics for penalties.
self.prompt_bin_mask = torch.zeros(
self.max_num_reqs,
max_num_reqs,
cdiv(self.vocab_size, 32),
dtype=torch.int32,
device=self.device,
@@ -34,10 +38,10 @@ class PenaltiesState:
# TODO(woosuk): This tensor is rarely used but can be very large, taking up
# GBs of GPU memory. Optimize the memory usage.
self.output_bin_counts = torch.zeros(
self.max_num_reqs, self.vocab_size, dtype=torch.int32, device=self.device
max_num_reqs, self.vocab_size, dtype=torch.int32, device=self.device
)
self._penalties_reqs: list[int] = []
self._new_penalties_reqs: list[int] = []
def add_request(self, req_idx: int, sampling_params: SamplingParams) -> None:
self.repetition_penalty.np[req_idx] = sampling_params.repetition_penalty
@@ -47,24 +51,29 @@ class PenaltiesState:
do_penalty = use_penalty(sampling_params)
self.use_penalty[req_idx] = do_penalty
if do_penalty:
self._penalties_reqs.append(req_idx)
self._new_penalties_reqs.append(req_idx)
def apply_staged_writes(
self,
all_token_ids: torch.Tensor,
prefill_lens: np.ndarray,
prompt_lens: np.ndarray,
) -> None:
# TODO(woosuk): Optimize this.
for req_idx in self._penalties_reqs:
bincount(
all_token_ids[req_idx],
int(prefill_lens[req_idx]),
int(prompt_lens[req_idx]),
self.prompt_bin_mask[req_idx],
self.output_bin_counts[req_idx],
def apply_staged_writes(self) -> None:
if self._new_penalties_reqs:
idx_mapping = async_tensor_h2d(
self._new_penalties_reqs,
dtype=torch.int32,
target_device=self.device,
pin_memory=True,
)
self._penalties_reqs.clear()
prefill_lens = self.req_states.prefill_len.np[self._new_penalties_reqs]
max_prefill_len = int(prefill_lens.max())
bincount(
idx_mapping,
self.req_states.all_token_ids.gpu,
self.req_states.prompt_len.gpu,
self.req_states.prefill_len.gpu,
self.prompt_bin_mask,
self.output_bin_counts,
max_prefill_len,
)
self._new_penalties_reqs.clear()
self.repetition_penalty.copy_to_uva()
self.frequency_penalty.copy_to_uva()
@@ -214,51 +223,82 @@ def apply_penalties(
)
@triton.jit(do_not_specialize=["prefill_len", "prompt_len"])
@triton.jit
def _bincount_kernel(
idx_mapping_ptr,
all_token_ids_ptr,
prefill_len,
prompt_len,
all_token_ids_stride,
prompt_len_ptr,
prefill_len_ptr,
prompt_bin_mask_ptr,
prompt_bin_mask_stride,
output_bin_counts_ptr,
output_bin_counts_stride,
BLOCK_SIZE: tl.constexpr,
):
block_idx = tl.program_id(0)
batch_idx = tl.program_id(0)
block_idx = tl.program_id(1)
req_state_idx = tl.load(idx_mapping_ptr + batch_idx)
prefill_len = tl.load(prefill_len_ptr + req_state_idx)
if block_idx * BLOCK_SIZE >= prefill_len:
return
prompt_len = tl.load(prompt_len_ptr + req_state_idx)
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
if block_idx * BLOCK_SIZE < prompt_len:
mask = block < prompt_len
prompt_tokens = tl.load(all_token_ids_ptr + block, mask=mask)
prompt_tokens = tl.load(
all_token_ids_ptr + req_state_idx * all_token_ids_stride + block, mask=mask
)
idx = prompt_tokens // 32
bit_idx = prompt_tokens % 32
bit = tl.full((BLOCK_SIZE,), 1, tl.int32) << bit_idx
tl.atomic_or(prompt_bin_mask_ptr + idx, bit, mask=mask)
tl.atomic_or(
prompt_bin_mask_ptr + req_state_idx * prompt_bin_mask_stride + idx,
bit,
mask=mask,
)
if (block_idx + 1) * BLOCK_SIZE >= prompt_len:
mask = block < prefill_len
mask &= block >= prompt_len
output_tokens = tl.load(all_token_ids_ptr + block, mask=mask)
tl.atomic_add(output_bin_counts_ptr + output_tokens, 1, mask=mask)
output_tokens = tl.load(
all_token_ids_ptr + req_state_idx * all_token_ids_stride + block, mask=mask
)
tl.atomic_add(
output_bin_counts_ptr
+ req_state_idx * output_bin_counts_stride
+ output_tokens,
1,
mask=mask,
)
def bincount(
idx_mapping: torch.Tensor,
all_token_ids: torch.Tensor,
prefill_len: int,
prompt_len: int,
prompt_len: torch.Tensor,
prefill_len: torch.Tensor,
prompt_bin_mask: torch.Tensor,
output_bin_counts: torch.Tensor,
max_prefill_len: int,
) -> None:
prompt_bin_mask.zero_()
output_bin_counts.zero_()
prompt_bin_mask[idx_mapping] = 0
output_bin_counts[idx_mapping] = 0
num_reqs = idx_mapping.shape[0]
BLOCK_SIZE = 1024
num_blocks = triton.cdiv(prefill_len, BLOCK_SIZE)
_bincount_kernel[(num_blocks,)](
num_blocks = triton.cdiv(max_prefill_len, BLOCK_SIZE)
_bincount_kernel[(num_reqs, num_blocks)](
idx_mapping,
all_token_ids,
prefill_len,
all_token_ids.stride(0),
prompt_len,
prefill_len,
prompt_bin_mask,
prompt_bin_mask.stride(0),
output_bin_counts,
output_bin_counts.stride(0),
BLOCK_SIZE=BLOCK_SIZE,
)

View File

@@ -15,6 +15,7 @@ from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
from vllm.v1.worker.gpu.sample.output import SamplerOutput
from vllm.v1.worker.gpu.sample.penalties import PenaltiesState
from vllm.v1.worker.gpu.sample.states import NO_LOGPROBS, SamplingStates
from vllm.v1.worker.gpu.states import RequestState
class Sampler:
@@ -23,9 +24,7 @@ class Sampler:
max_num_reqs: int,
vocab_size: int,
device: torch.device,
all_token_ids: torch.Tensor,
prompt_len: torch.Tensor,
total_len: torch.Tensor,
req_states: RequestState,
logprobs_mode: LogprobsMode = "raw_logprobs",
num_speculative_tokens: int = 1,
):
@@ -35,9 +34,9 @@ class Sampler:
self.compute_nans = envs.VLLM_COMPUTE_NANS_IN_LOGITS # False by default.
self.sampling_states = SamplingStates(max_num_reqs, vocab_size)
self.penalties_state = PenaltiesState(max_num_reqs, vocab_size, device)
self.penalties_state = PenaltiesState(req_states)
self.logit_bias_state = LogitBiasState(max_num_reqs, device)
self.bad_words_state = BadWordsState(all_token_ids, prompt_len, total_len)
self.bad_words_state = BadWordsState(req_states)
self.num_speculative_tokens = num_speculative_tokens
def add_request(
@@ -48,16 +47,9 @@ class Sampler:
self.logit_bias_state.add_request(req_idx, prompt_len, sampling_params)
self.bad_words_state.add_request(req_idx, sampling_params)
def apply_staged_writes(
self,
all_token_ids: torch.Tensor,
prefill_lens: np.ndarray,
prompt_lens: np.ndarray,
) -> None:
def apply_staged_writes(self) -> None:
self.sampling_states.apply_staged_writes()
self.penalties_state.apply_staged_writes(
all_token_ids, prefill_lens, prompt_lens
)
self.penalties_state.apply_staged_writes()
self.logit_bias_state.apply_staged_writes()
self.bad_words_state.apply_staged_writes()