From d00df624f313a6a5a7a6245b71448b068b080cd7 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 16 Feb 2026 21:43:00 -0800 Subject: [PATCH] [Model Runner V2] Minor refactoring for penalties (#34662) Signed-off-by: Woosuk Kwon --- vllm/v1/worker/gpu/model_runner.py | 10 +-- vllm/v1/worker/gpu/sample/bad_words.py | 23 ++--- vllm/v1/worker/gpu/sample/penalties.py | 116 +++++++++++++++++-------- vllm/v1/worker/gpu/sample/sampler.py | 20 ++--- 4 files changed, 94 insertions(+), 75 deletions(-) diff --git a/vllm/v1/worker/gpu/model_runner.py b/vllm/v1/worker/gpu/model_runner.py index 273cecd3b..0ca0e828b 100644 --- a/vllm/v1/worker/gpu/model_runner.py +++ b/vllm/v1/worker/gpu/model_runner.py @@ -155,9 +155,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): max_num_reqs=self.max_num_reqs, vocab_size=self.vocab_size, device=self.device, - all_token_ids=self.req_states.all_token_ids.gpu, - prompt_len=self.req_states.prompt_len.gpu, - total_len=self.req_states.total_len.gpu, + req_states=self.req_states, logprobs_mode=self.model_config.logprobs_mode, num_speculative_tokens=self.num_speculative_steps + 1, ) @@ -528,11 +526,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): if scheduler_output.scheduled_new_reqs: self.req_states.apply_staged_writes() - self.sampler.apply_staged_writes( - self.req_states.all_token_ids.gpu, - self.req_states.prefill_len.np, - self.req_states.prompt_len.np, - ) + self.sampler.apply_staged_writes() if self.uses_mrope: self.mrope_states.apply_staged_writes() diff --git a/vllm/v1/worker/gpu/sample/bad_words.py b/vllm/v1/worker/gpu/sample/bad_words.py index 4c156e811..2c7dc1327 100644 --- a/vllm/v1/worker/gpu/sample/bad_words.py +++ b/vllm/v1/worker/gpu/sample/bad_words.py @@ -6,24 +6,17 @@ import torch from vllm.sampling_params import SamplingParams from vllm.triton_utils import tl, triton from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor +from vllm.v1.worker.gpu.states import RequestState MAX_BAD_WORDS_TOTAL_TOKENS = 1024 # Max total tokens for all bad words per request MAX_NUM_BAD_WORDS = 128 # Max number of bad words per request class BadWordsState: - def __init__( - self, - all_token_ids: torch.Tensor, - prompt_len: torch.Tensor, - total_len: torch.Tensor, - ): - self.all_token_ids = all_token_ids - self.prompt_len = prompt_len - self.total_len = total_len - - self.max_num_reqs = prompt_len.shape[0] - self.device = prompt_len.device + def __init__(self, req_states: RequestState): + self.req_states = req_states + self.max_num_reqs = req_states.max_num_reqs + self.device = req_states.device # flattened bad word tokens: [max_num_reqs, MAX_BAD_WORDS_TOTAL_TOKENS] self.bad_word_token_ids = StagedWriteTensor( @@ -95,9 +88,9 @@ class BadWordsState: self.bad_word_token_ids.gpu, self.bad_word_offsets.gpu, self.num_bad_words.gpu, - self.all_token_ids, - self.prompt_len, - self.total_len, + self.req_states.all_token_ids.gpu, + self.req_states.prompt_len.gpu, + self.req_states.total_len.gpu, input_ids, expanded_local_pos, max_num_bad_words, diff --git a/vllm/v1/worker/gpu/sample/penalties.py b/vllm/v1/worker/gpu/sample/penalties.py index 8671dd7e0..e926d550f 100644 --- a/vllm/v1/worker/gpu/sample/penalties.py +++ b/vllm/v1/worker/gpu/sample/penalties.py @@ -6,14 +6,18 @@ import torch from vllm.sampling_params import SamplingParams from vllm.triton_utils import tl, triton from vllm.utils.math_utils import cdiv +from vllm.utils.torch_utils import async_tensor_h2d from vllm.v1.worker.gpu.buffer_utils import UvaBackedTensor +from vllm.v1.worker.gpu.states import RequestState class PenaltiesState: - def __init__(self, max_num_reqs: int, vocab_size: int, device: torch.device): - self.max_num_reqs = max_num_reqs - self.vocab_size = vocab_size - self.device = device + def __init__(self, req_states: RequestState): + self.req_states = req_states + + max_num_reqs = req_states.max_num_reqs + self.vocab_size = req_states.vocab_size + self.device = req_states.device self.repetition_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32) self.frequency_penalty = UvaBackedTensor(max_num_reqs, dtype=torch.float32) @@ -26,7 +30,7 @@ class PenaltiesState: # Statistics for penalties. self.prompt_bin_mask = torch.zeros( - self.max_num_reqs, + max_num_reqs, cdiv(self.vocab_size, 32), dtype=torch.int32, device=self.device, @@ -34,10 +38,10 @@ class PenaltiesState: # TODO(woosuk): This tensor is rarely used but can be very large, taking up # GBs of GPU memory. Optimize the memory usage. self.output_bin_counts = torch.zeros( - self.max_num_reqs, self.vocab_size, dtype=torch.int32, device=self.device + max_num_reqs, self.vocab_size, dtype=torch.int32, device=self.device ) - self._penalties_reqs: list[int] = [] + self._new_penalties_reqs: list[int] = [] def add_request(self, req_idx: int, sampling_params: SamplingParams) -> None: self.repetition_penalty.np[req_idx] = sampling_params.repetition_penalty @@ -47,24 +51,29 @@ class PenaltiesState: do_penalty = use_penalty(sampling_params) self.use_penalty[req_idx] = do_penalty if do_penalty: - self._penalties_reqs.append(req_idx) + self._new_penalties_reqs.append(req_idx) - def apply_staged_writes( - self, - all_token_ids: torch.Tensor, - prefill_lens: np.ndarray, - prompt_lens: np.ndarray, - ) -> None: - # TODO(woosuk): Optimize this. - for req_idx in self._penalties_reqs: - bincount( - all_token_ids[req_idx], - int(prefill_lens[req_idx]), - int(prompt_lens[req_idx]), - self.prompt_bin_mask[req_idx], - self.output_bin_counts[req_idx], + def apply_staged_writes(self) -> None: + if self._new_penalties_reqs: + idx_mapping = async_tensor_h2d( + self._new_penalties_reqs, + dtype=torch.int32, + target_device=self.device, + pin_memory=True, ) - self._penalties_reqs.clear() + + prefill_lens = self.req_states.prefill_len.np[self._new_penalties_reqs] + max_prefill_len = int(prefill_lens.max()) + bincount( + idx_mapping, + self.req_states.all_token_ids.gpu, + self.req_states.prompt_len.gpu, + self.req_states.prefill_len.gpu, + self.prompt_bin_mask, + self.output_bin_counts, + max_prefill_len, + ) + self._new_penalties_reqs.clear() self.repetition_penalty.copy_to_uva() self.frequency_penalty.copy_to_uva() @@ -214,51 +223,82 @@ def apply_penalties( ) -@triton.jit(do_not_specialize=["prefill_len", "prompt_len"]) +@triton.jit def _bincount_kernel( + idx_mapping_ptr, all_token_ids_ptr, - prefill_len, - prompt_len, + all_token_ids_stride, + prompt_len_ptr, + prefill_len_ptr, prompt_bin_mask_ptr, + prompt_bin_mask_stride, output_bin_counts_ptr, + output_bin_counts_stride, BLOCK_SIZE: tl.constexpr, ): - block_idx = tl.program_id(0) + batch_idx = tl.program_id(0) + block_idx = tl.program_id(1) + req_state_idx = tl.load(idx_mapping_ptr + batch_idx) + + prefill_len = tl.load(prefill_len_ptr + req_state_idx) if block_idx * BLOCK_SIZE >= prefill_len: return + prompt_len = tl.load(prompt_len_ptr + req_state_idx) block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) if block_idx * BLOCK_SIZE < prompt_len: mask = block < prompt_len - prompt_tokens = tl.load(all_token_ids_ptr + block, mask=mask) + prompt_tokens = tl.load( + all_token_ids_ptr + req_state_idx * all_token_ids_stride + block, mask=mask + ) idx = prompt_tokens // 32 bit_idx = prompt_tokens % 32 bit = tl.full((BLOCK_SIZE,), 1, tl.int32) << bit_idx - tl.atomic_or(prompt_bin_mask_ptr + idx, bit, mask=mask) + tl.atomic_or( + prompt_bin_mask_ptr + req_state_idx * prompt_bin_mask_stride + idx, + bit, + mask=mask, + ) + if (block_idx + 1) * BLOCK_SIZE >= prompt_len: mask = block < prefill_len mask &= block >= prompt_len - output_tokens = tl.load(all_token_ids_ptr + block, mask=mask) - tl.atomic_add(output_bin_counts_ptr + output_tokens, 1, mask=mask) + output_tokens = tl.load( + all_token_ids_ptr + req_state_idx * all_token_ids_stride + block, mask=mask + ) + tl.atomic_add( + output_bin_counts_ptr + + req_state_idx * output_bin_counts_stride + + output_tokens, + 1, + mask=mask, + ) def bincount( + idx_mapping: torch.Tensor, all_token_ids: torch.Tensor, - prefill_len: int, - prompt_len: int, + prompt_len: torch.Tensor, + prefill_len: torch.Tensor, prompt_bin_mask: torch.Tensor, output_bin_counts: torch.Tensor, + max_prefill_len: int, ) -> None: - prompt_bin_mask.zero_() - output_bin_counts.zero_() + prompt_bin_mask[idx_mapping] = 0 + output_bin_counts[idx_mapping] = 0 + num_reqs = idx_mapping.shape[0] BLOCK_SIZE = 1024 - num_blocks = triton.cdiv(prefill_len, BLOCK_SIZE) - _bincount_kernel[(num_blocks,)]( + num_blocks = triton.cdiv(max_prefill_len, BLOCK_SIZE) + _bincount_kernel[(num_reqs, num_blocks)]( + idx_mapping, all_token_ids, - prefill_len, + all_token_ids.stride(0), prompt_len, + prefill_len, prompt_bin_mask, + prompt_bin_mask.stride(0), output_bin_counts, + output_bin_counts.stride(0), BLOCK_SIZE=BLOCK_SIZE, ) diff --git a/vllm/v1/worker/gpu/sample/sampler.py b/vllm/v1/worker/gpu/sample/sampler.py index d5f66a39e..87b10bcc1 100644 --- a/vllm/v1/worker/gpu/sample/sampler.py +++ b/vllm/v1/worker/gpu/sample/sampler.py @@ -15,6 +15,7 @@ from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs from vllm.v1.worker.gpu.sample.output import SamplerOutput from vllm.v1.worker.gpu.sample.penalties import PenaltiesState from vllm.v1.worker.gpu.sample.states import NO_LOGPROBS, SamplingStates +from vllm.v1.worker.gpu.states import RequestState class Sampler: @@ -23,9 +24,7 @@ class Sampler: max_num_reqs: int, vocab_size: int, device: torch.device, - all_token_ids: torch.Tensor, - prompt_len: torch.Tensor, - total_len: torch.Tensor, + req_states: RequestState, logprobs_mode: LogprobsMode = "raw_logprobs", num_speculative_tokens: int = 1, ): @@ -35,9 +34,9 @@ class Sampler: self.compute_nans = envs.VLLM_COMPUTE_NANS_IN_LOGITS # False by default. self.sampling_states = SamplingStates(max_num_reqs, vocab_size) - self.penalties_state = PenaltiesState(max_num_reqs, vocab_size, device) + self.penalties_state = PenaltiesState(req_states) self.logit_bias_state = LogitBiasState(max_num_reqs, device) - self.bad_words_state = BadWordsState(all_token_ids, prompt_len, total_len) + self.bad_words_state = BadWordsState(req_states) self.num_speculative_tokens = num_speculative_tokens def add_request( @@ -48,16 +47,9 @@ class Sampler: self.logit_bias_state.add_request(req_idx, prompt_len, sampling_params) self.bad_words_state.add_request(req_idx, sampling_params) - def apply_staged_writes( - self, - all_token_ids: torch.Tensor, - prefill_lens: np.ndarray, - prompt_lens: np.ndarray, - ) -> None: + def apply_staged_writes(self) -> None: self.sampling_states.apply_staged_writes() - self.penalties_state.apply_staged_writes( - all_token_ids, prefill_lens, prompt_lens - ) + self.penalties_state.apply_staged_writes() self.logit_bias_state.apply_staged_writes() self.bad_words_state.apply_staged_writes()