[Model Runner V2] support bad_words sampling param (#33433)

Signed-off-by: zhuhaoran <zhuhaoran.zhr@alibaba-inc.com>
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
Co-authored-by: Woosuk Kwon <woosuk@inferact.ai>
Authored by: zhrrr
Committed by GitHub: 2026-02-17 08:36:06 +08:00
Parent: 3b30e61507
Commit: 387a1898d9

7 changed files with 303 additions and 43 deletions
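For context, the parameter this change wires into Model Runner V2 is exposed through `SamplingParams`. A minimal offline-inference sketch follows; the model name is a placeholder, and the `bad_words` strings are tokenized upstream before the sampler sees them as `bad_words_token_ids`:

from vllm import LLM, SamplingParams

# Placeholder checkpoint; any supported model works the same way.
llm = LLM(model="facebook/opt-125m")

# Each bad word is tokenized internally; during decoding the sampler masks the
# logit of a bad word's final token whenever the preceding tokens of that word
# already appear at the end of the generated output.
params = SamplingParams(bad_words=["Hello world"], max_tokens=32)

outputs = llm.generate(["The greeting was:"], params)
print(outputs[0].outputs[0].text)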


@@ -156,8 +156,8 @@ def _prepare_prefill_inputs_kernel(
next_prefill_tokens_ptr,
idx_mapping_ptr,
query_start_loc_ptr,
prefill_token_ids_ptr,
prefill_token_ids_stride,
all_token_ids_ptr,
all_token_ids_stride,
prefill_lens_ptr,
num_computed_tokens_ptr,
BLOCK_SIZE: tl.constexpr,
@@ -174,16 +174,16 @@ def _prepare_prefill_inputs_kernel(
query_end = tl.load(query_start_loc_ptr + batch_idx + 1)
query_len = query_end - query_start
prefill_ptr = prefill_token_ids_ptr + req_state_idx * prefill_token_ids_stride
request_ptr = all_token_ids_ptr + req_state_idx * all_token_ids_stride
for i in range(0, query_len, BLOCK_SIZE):
block = i + tl.arange(0, BLOCK_SIZE)
mask = block < query_len
tokens = tl.load(prefill_ptr + num_computed + block, mask=mask)
tokens = tl.load(request_ptr + num_computed + block, mask=mask)
tl.store(input_ids_ptr + query_start + block, tokens, mask=mask)
next_pos = num_computed + query_len
if next_pos < prefill_len:
next_token = tl.load(prefill_ptr + next_pos)
next_token = tl.load(request_ptr + next_pos)
tl.store(next_prefill_tokens_ptr + req_state_idx, next_token)
@@ -192,7 +192,7 @@ def prepare_prefill_inputs(
next_prefill_tokens: torch.Tensor,
idx_mapping: torch.Tensor,
query_start_loc: torch.Tensor,
prefill_token_ids: torch.Tensor,
all_token_ids: torch.Tensor,
prefill_len: torch.Tensor,
num_computed_tokens: torch.Tensor,
) -> None:
@@ -202,8 +202,8 @@ def prepare_prefill_inputs(
next_prefill_tokens,
idx_mapping,
query_start_loc,
prefill_token_ids,
prefill_token_ids.stride(0),
all_token_ids,
all_token_ids.stride(0),
prefill_len,
num_computed_tokens,
BLOCK_SIZE=1024,
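For orientation, a rough NumPy-style rendering of what `_prepare_prefill_inputs_kernel` gathers from the renamed buffer. This is a sketch only: the real kernel launches one program per request and copies BLOCK_SIZE-sized chunks, and the indexing of `num_computed_tokens` and `prefill_lens` by request state index is assumed from the surrounding code.

def prepare_prefill_inputs_reference(
    input_ids, next_prefill_tokens, idx_mapping, query_start_loc,
    all_token_ids, prefill_lens, num_computed_tokens,
):
    num_reqs = idx_mapping.shape[0]
    for batch_idx in range(num_reqs):
        req_state_idx = idx_mapping[batch_idx]
        num_computed = num_computed_tokens[req_state_idx]
        query_start = query_start_loc[batch_idx]
        query_len = query_start_loc[batch_idx + 1] - query_start
        # Copy this step's chunk of the request's token history into the
        # flattened input_ids buffer of the batch.
        input_ids[query_start:query_start + query_len] = all_token_ids[
            req_state_idx, num_computed:num_computed + query_len
        ]
        # Remember the token right after the chunk if the prefill is not finished.
        next_pos = num_computed + query_len
        if next_pos < prefill_lens[req_state_idx]:
            next_prefill_tokens[req_state_idx] = all_token_ids[req_state_idx, next_pos]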
@@ -423,16 +423,21 @@ def _post_update_kernel(
num_sampled_ptr,
num_rejected_ptr,
query_start_loc_ptr,
all_token_ids_ptr,
all_token_ids_stride,
total_len_ptr,
):
req_id = tl.program_id(0)
req_state_idx = tl.load(idx_mapping_ptr + req_id)
total_len = tl.load(total_len_ptr + req_state_idx)
num_sampled = tl.load(num_sampled_ptr + req_id)
if num_sampled > 0:
token_id = tl.load(
sampled_tokens_ptr + req_id * sampled_tokens_stride + num_sampled - 1
)
tl.store(last_sampled_tokens_ptr + req_state_idx, token_id)
tl.store(total_len_ptr + req_state_idx, total_len + num_sampled)
for i in range(num_sampled):
token_id = tl.load(sampled_tokens_ptr + req_id * sampled_tokens_stride + i)
@@ -442,6 +447,10 @@ def _post_update_kernel(
count = tl.load(token_ptr)
count += 1
tl.store(token_ptr, count)
tl.store(
all_token_ids_ptr + req_state_idx * all_token_ids_stride + total_len + i,
token_id,
)
query_start = tl.load(query_start_loc_ptr + req_id)
query_end = tl.load(query_start_loc_ptr + req_id + 1)
@@ -470,6 +479,10 @@ def post_update(
num_rejected: torch.Tensor,
# [num_reqs + 1]
query_start_loc: torch.Tensor,
# [max_num_reqs, max_model_len]
all_token_ids: torch.Tensor,
# [max_num_reqs]
total_len: torch.Tensor,
) -> None:
num_reqs = idx_mapping.shape[0]
_post_update_kernel[(num_reqs,)](
@@ -483,6 +496,9 @@ def post_update(
num_sampled,
num_rejected,
query_start_loc,
all_token_ids,
all_token_ids.stride(0),
total_len,
num_warps=1,
)
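The new arguments let `post_update` append each accepted token to the per-request token history and advance `total_len`, which is what the bad-words kernel later matches against. A hedged pure-Python rendering of this bookkeeping (it omits the pre-existing penalty bin-count update and the num_rejected/query_start_loc handling):

def post_update_reference(
    sampled_tokens, last_sampled_tokens, idx_mapping, num_sampled,
    all_token_ids, total_len,
):
    num_reqs = idx_mapping.shape[0]
    for req_id in range(num_reqs):
        req_state_idx = idx_mapping[req_id]
        n = num_sampled[req_id]
        if n > 0:
            # Keep the most recent accepted token for the next decoding step.
            last_sampled_tokens[req_state_idx] = sampled_tokens[req_id, n - 1]
            # New in this change: append the accepted tokens to the request's
            # token history and grow total_len accordingly.
            start = total_len[req_state_idx]
            all_token_ids[req_state_idx, start:start + n] = sampled_tokens[req_id, :n]
            total_len[req_state_idx] = start + n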


@@ -151,6 +151,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
max_num_reqs=self.max_num_reqs,
vocab_size=self.vocab_size,
device=self.device,
all_token_ids=self.req_states.all_token_ids.gpu,
prompt_len=self.req_states.prompt_len.gpu,
total_len=self.req_states.total_len.gpu,
logprobs_mode=self.model_config.logprobs_mode,
num_speculative_tokens=self.num_speculative_steps + 1,
)
@@ -448,7 +451,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.req_states.add_request(
req_id=req_id,
prompt_len=prompt_len,
prefill_token_ids=new_req_data.prefill_token_ids,
all_token_ids=new_req_data.prefill_token_ids,
num_computed_tokens=new_req_data.num_computed_tokens,
)
req_index = self.req_states.req_id_to_index[req_id]
@@ -479,9 +482,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if scheduler_output.scheduled_new_reqs:
self.req_states.apply_staged_writes()
self.sampler.apply_staged_writes(
self.req_states.prefill_token_ids.gpu,
self.req_states.all_token_ids.gpu,
self.req_states.prefill_len.np,
self.req_states.prompt_len,
self.req_states.prompt_len.np,
)
if self.uses_mrope:
self.mrope_states.apply_staged_writes()
@@ -570,7 +573,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.req_states.next_prefill_tokens,
idx_mapping,
query_start_loc,
self.req_states.prefill_token_ids.gpu,
self.req_states.all_token_ids.gpu,
self.req_states.prefill_len.gpu,
self.req_states.num_computed_tokens.gpu,
)
@@ -759,6 +762,8 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_sampled,
num_rejected,
input_batch.query_start_loc,
self.req_states.all_token_ids.gpu,
self.req_states.total_len.gpu,
)
# Update the number of computed prefill tokens.
@@ -924,9 +929,9 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.model.compute_logits,
hidden_states,
input_batch,
self.req_states.prefill_token_ids.gpu,
self.req_states.all_token_ids.gpu,
self.req_states.num_computed_tokens.gpu,
self.req_states.prompt_len,
self.req_states.prompt_len.np,
self.req_states.prefill_len.np,
self.req_states.num_computed_prefill_tokens,
)


@@ -0,0 +1,209 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch
from vllm.sampling_params import SamplingParams
from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu.buffer_utils import StagedWriteTensor, UvaBackedTensor
MAX_BAD_WORDS_TOTAL_TOKENS = 1024 # Max total tokens for all bad words per request
MAX_NUM_BAD_WORDS = 128 # Max number of bad words per request

class BadWordsState:
    def __init__(
        self,
        all_token_ids: torch.Tensor,
        prompt_len: torch.Tensor,
        total_len: torch.Tensor,
    ):
        self.all_token_ids = all_token_ids
        self.prompt_len = prompt_len
        self.total_len = total_len
        self.max_num_reqs = prompt_len.shape[0]
        self.device = prompt_len.device

        # flattened bad word tokens: [max_num_reqs, MAX_BAD_WORDS_TOTAL_TOKENS]
        self.bad_word_token_ids = StagedWriteTensor(
            (self.max_num_reqs, MAX_BAD_WORDS_TOTAL_TOKENS),
            dtype=torch.int32,
            device=self.device,
        )
        # cumulative offsets of bad words: [max_num_reqs, MAX_NUM_BAD_WORDS + 1]
        self.bad_word_offsets = StagedWriteTensor(
            (self.max_num_reqs, MAX_NUM_BAD_WORDS + 1),
            dtype=torch.int32,
            device=self.device,
        )
        # number of bad words per request
        self.num_bad_words = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
        # whether request uses bad words
        self.use_bad_words = np.zeros(self.max_num_reqs, dtype=bool)

    def add_request(
        self,
        req_idx: int,
        sampling_params: SamplingParams,
    ) -> None:
        bad_words_token_ids = sampling_params.bad_words_token_ids
        if not bad_words_token_ids:
            self.num_bad_words.np[req_idx] = 0
            self.use_bad_words[req_idx] = False
            return

        num_bad_words = len(bad_words_token_ids)
        if num_bad_words > MAX_NUM_BAD_WORDS:
            raise ValueError(
                f"Too many bad words: {num_bad_words}. "
                f"The max number is {MAX_NUM_BAD_WORDS}."
            )

        # Flatten bad words and compute offsets
        flattened_tokens: list[int] = []
        offsets: list[int] = [0]
        for bad_word in bad_words_token_ids:
            flattened_tokens.extend(bad_word)
            offsets.append(len(flattened_tokens))
        if len(flattened_tokens) > MAX_BAD_WORDS_TOTAL_TOKENS:
            raise ValueError(
                f"Too many total bad word tokens: {len(flattened_tokens)}. "
                f"The max is {MAX_BAD_WORDS_TOTAL_TOKENS}."
            )

        # Stage writes
        self.bad_word_token_ids.stage_write(req_idx, 0, flattened_tokens)
        self.bad_word_offsets.stage_write(req_idx, 0, offsets)
        self.num_bad_words.np[req_idx] = num_bad_words
        self.use_bad_words[req_idx] = True

    def apply_staged_writes(self) -> None:
        self.num_bad_words.copy_to_uva()
        self.bad_word_token_ids.apply_write()
        self.bad_word_offsets.apply_write()

    def apply_bad_words(
        self,
        logits: torch.Tensor,
        idx_mapping: torch.Tensor,
        idx_mapping_np: np.ndarray,
        input_ids: torch.Tensor,
        expanded_local_pos: torch.Tensor,
    ) -> None:
        if not np.any(self.use_bad_words[idx_mapping_np]):
            # No request uses bad words. Skip the kernel launch.
            return
        actual_max_num_bad_words = int(np.max(self.num_bad_words.np[idx_mapping_np]))
        apply_bad_words(
            logits,
            idx_mapping,
            self.bad_word_token_ids.gpu,
            self.bad_word_offsets.gpu,
            self.num_bad_words.gpu,
            self.all_token_ids,
            self.prompt_len,
            self.total_len,
            input_ids,
            expanded_local_pos,
            actual_max_num_bad_words,
        )

@triton.jit
def _bad_words_kernel(
    logits_ptr,
    logits_stride,
    expanded_idx_mapping_ptr,
    bad_word_token_ids_ptr,
    bad_word_token_ids_stride,
    bad_word_offsets_ptr,
    bad_word_offsets_stride,
    num_bad_words_ptr,
    all_token_ids_ptr,
    all_token_ids_stride,
    prompt_len_ptr,
    total_len_ptr,
    input_ids_ptr,
    expanded_local_pos_ptr,
):
    logit_idx = tl.program_id(0)
    bw_idx = tl.program_id(1)
    req_state_idx = tl.load(expanded_idx_mapping_ptr + logit_idx)
    num_bad_words = tl.load(num_bad_words_ptr + req_state_idx)
    if bw_idx >= num_bad_words:
        return

    pos = tl.load(expanded_local_pos_ptr + logit_idx)
    cur_req_first_pos = logit_idx - pos
    prompt_len = tl.load(prompt_len_ptr + req_state_idx)
    total_len = tl.load(total_len_ptr + req_state_idx)
    output_len = total_len - prompt_len
    effective_len = output_len + pos

    bd_offsets_base = bad_word_offsets_ptr + req_state_idx * bad_word_offsets_stride
    bd_tokens_base = bad_word_token_ids_ptr + req_state_idx * bad_word_token_ids_stride
    output_base = all_token_ids_ptr + req_state_idx * all_token_ids_stride + prompt_len

    start = tl.load(bd_offsets_base + bw_idx)
    end = tl.load(bd_offsets_base + bw_idx + 1)
    bad_word_len = end - start
    prefix_len = bad_word_len - 1
    if prefix_len > effective_len:
        return

    last_token = tl.load(bd_tokens_base + end - 1)
    match = 1
    for i in range(prefix_len):
        expected = tl.load(bd_tokens_base + start + i)
        actual_pos = effective_len - prefix_len + i
        from_spec_input = actual_pos >= output_len
        if from_spec_input:
            spec_offset = actual_pos - output_len
            actual = tl.load(input_ids_ptr + cur_req_first_pos + spec_offset)
        else:
            actual = tl.load(output_base + actual_pos)
        match = match & (expected == actual)
    if match:
        tl.store(logits_ptr + logit_idx * logits_stride + last_token, -float("inf"))

def apply_bad_words(
    logits: torch.Tensor,
    expanded_idx_mapping: torch.Tensor,
    bad_word_token_ids: torch.Tensor,
    bad_word_offsets: torch.Tensor,
    num_bad_words: torch.Tensor,
    all_token_ids: torch.Tensor,
    prompt_len: torch.Tensor,
    total_len: torch.Tensor,
    input_ids: torch.Tensor,
    expanded_local_pos: torch.Tensor,
    max_num_bad_words: int,
) -> None:
    total_num_tokens = logits.shape[0]
    _bad_words_kernel[(total_num_tokens, max_num_bad_words)](
        logits,
        logits.stride(0),
        expanded_idx_mapping,
        bad_word_token_ids,
        bad_word_token_ids.stride(0),
        bad_word_offsets,
        bad_word_offsets.stride(0),
        num_bad_words,
        all_token_ids,
        all_token_ids.stride(0),
        prompt_len,
        total_len,
        input_ids,
        expanded_local_pos,
    )
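The per-request storage is CSR-like: bad_word_token_ids holds all bad words back to back and bad_word_offsets marks word boundaries, so word i spans tokens[offsets[i]:offsets[i+1]] (e.g. [[5, 7], [9]] becomes tokens [5, 7, 9] with offsets [0, 2, 3]). As a sanity check on the masking rule, here is a hedged pure-Python rendering for a single request during ordinary decoding; names are local to this sketch, and it ignores the speculative-decoding path in which the kernel also reads input_ids for draft tokens accepted earlier in the same step:

def mask_bad_words_reference(
    logits: list[float],
    output_tokens: list[int],
    bad_words_token_ids: list[list[int]],
) -> None:
    for bad_word in bad_words_token_ids:
        prefix, last_token = bad_word[:-1], bad_word[-1]
        if len(prefix) > len(output_tokens):
            # Not enough generated history for the prefix to have appeared yet.
            continue
        suffix = output_tokens[len(output_tokens) - len(prefix):] if prefix else []
        if suffix == prefix:
            # Ban only the final token of the bad word; its earlier tokens stay
            # allowed, so partial matches can occur and later diverge.
            logits[last_token] = float("-inf")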


@@ -51,14 +51,14 @@ class PenaltiesState:
def apply_staged_writes(
self,
prefill_token_ids: torch.Tensor,
all_token_ids: torch.Tensor,
prefill_lens: np.ndarray,
prompt_lens: np.ndarray,
) -> None:
# TODO(woosuk): Optimize this.
for req_idx in self._penalties_reqs:
bincount(
prefill_token_ids[req_idx],
all_token_ids[req_idx],
int(prefill_lens[req_idx]),
int(prompt_lens[req_idx]),
self.prompt_bin_mask[req_idx],
@@ -216,7 +216,7 @@ def apply_penalties(
@triton.jit(do_not_specialize=["prefill_len", "prompt_len"])
def _bincount_kernel(
prefill_token_ids_ptr,
all_token_ids_ptr,
prefill_len,
prompt_len,
prompt_bin_mask_ptr,
@@ -230,20 +230,20 @@ def _bincount_kernel(
block = block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
if block_idx * BLOCK_SIZE < prompt_len:
mask = block < prompt_len
prefill_tokens = tl.load(prefill_token_ids_ptr + block, mask=mask)
idx = prefill_tokens // 32
bit_idx = prefill_tokens % 32
prompt_tokens = tl.load(all_token_ids_ptr + block, mask=mask)
idx = prompt_tokens // 32
bit_idx = prompt_tokens % 32
bit = tl.full((BLOCK_SIZE,), 1, tl.int32) << bit_idx
tl.atomic_or(prompt_bin_mask_ptr + idx, bit, mask=mask)
if (block_idx + 1) * BLOCK_SIZE >= prompt_len:
mask = block < prefill_len
mask &= block >= prompt_len
prefill_tokens = tl.load(prefill_token_ids_ptr + block, mask=mask)
tl.atomic_add(output_bin_counts_ptr + prefill_tokens, 1, mask=mask)
output_tokens = tl.load(all_token_ids_ptr + block, mask=mask)
tl.atomic_add(output_bin_counts_ptr + output_tokens, 1, mask=mask)
def bincount(
prefill_token_ids: torch.Tensor,
all_token_ids: torch.Tensor,
prefill_len: int,
prompt_len: int,
prompt_bin_mask: torch.Tensor,
@@ -254,7 +254,7 @@ def bincount(
BLOCK_SIZE = 1024
num_blocks = triton.cdiv(prefill_len, BLOCK_SIZE)
_bincount_kernel[(num_blocks,)](
prefill_token_ids,
all_token_ids,
prefill_len,
prompt_len,
prompt_bin_mask,
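For readers following the rename, a NumPy sketch of what `bincount` computes per request from one row of `all_token_ids`: the first `prompt_len` tokens feed a presence mask, while tokens between `prompt_len` and `prefill_len` (output tokens replayed through prefill, e.g. after preemption) are counted per token. The real kernel packs the mask into 32-bit words; a boolean array keeps the sketch simple.

import numpy as np

def bincount_reference(all_token_ids, prefill_len, prompt_len, vocab_size):
    # Presence of each vocab id anywhere in the prompt.
    prompt_mask = np.zeros(vocab_size, dtype=bool)
    prompt_mask[all_token_ids[:prompt_len]] = True
    # Per-token counts for already-generated output tokens being replayed
    # through prefill.
    output_bin_counts = np.bincount(
        all_token_ids[prompt_len:prefill_len], minlength=vocab_size
    ).astype(np.int32)
    return prompt_mask, output_bin_counts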


@@ -36,7 +36,7 @@ class PromptLogprobsWorker:
hidden_states: torch.Tensor,
input_batch: InputBatch,
# [max_num_reqs, max_model_len]
prefill_token_ids: torch.Tensor,
all_token_ids: torch.Tensor,
# [max_num_reqs]
num_computed_tokens: torch.Tensor,
# [max_num_reqs]
@@ -70,7 +70,7 @@ class PromptLogprobsWorker:
input_batch.query_start_loc,
input_batch.idx_mapping,
num_computed_tokens,
prefill_token_ids,
all_token_ids,
)
# Compute the prompt logprobs.
prompt_logprobs, prompt_ranks = compute_prompt_logprobs_with_chunking(
@@ -132,8 +132,8 @@ def _prompt_logprobs_token_ids_kernel(
query_start_loc_ptr,
idx_mapping_ptr,
num_computed_tokens_ptr,
prefill_token_ids_ptr,
prefill_token_ids_stride,
all_token_ids_ptr,
all_token_ids_stride,
BLOCK_SIZE: tl.constexpr,
):
batch_idx = tl.program_id(0)
@@ -151,9 +151,7 @@ def _prompt_logprobs_token_ids_kernel(
# because the logprob is computed for the next token.
target_pos = num_computed_tokens + 1 + block
token_ids = tl.load(
prefill_token_ids_ptr
+ req_state_idx * prefill_token_ids_stride
+ target_pos,
all_token_ids_ptr + req_state_idx * all_token_ids_stride + target_pos,
mask=mask,
)
tl.store(
@@ -166,7 +164,7 @@ def get_prompt_logprobs_token_ids(
query_start_loc: torch.Tensor,
idx_mapping: torch.Tensor,
num_computed_tokens: torch.Tensor,
prefill_token_ids: torch.Tensor,
all_token_ids: torch.Tensor,
) -> torch.Tensor:
token_ids = torch.empty(num_tokens, dtype=torch.int64, device=idx_mapping.device)
num_reqs = idx_mapping.shape[0]
@@ -175,8 +173,8 @@ def get_prompt_logprobs_token_ids(
query_start_loc,
idx_mapping,
num_computed_tokens,
prefill_token_ids,
prefill_token_ids.stride(0),
all_token_ids,
all_token_ids.stride(0),
BLOCK_SIZE=1024,
)
return token_ids


@@ -8,6 +8,7 @@ import vllm.envs as envs
from vllm.config.model import LogprobsMode
from vllm.sampling_params import SamplingParams
from vllm.v1.worker.gpu.metrics.logits import get_num_nans
from vllm.v1.worker.gpu.sample.bad_words import BadWordsState
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
from vllm.v1.worker.gpu.sample.logit_bias import LogitBiasState
from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
@@ -22,6 +23,9 @@ class Sampler:
max_num_reqs: int,
vocab_size: int,
device: torch.device,
all_token_ids: torch.Tensor,
prompt_len: torch.Tensor,
total_len: torch.Tensor,
logprobs_mode: LogprobsMode = "raw_logprobs",
num_speculative_tokens: int = 1,
):
@@ -33,6 +37,7 @@ class Sampler:
self.sampling_states = SamplingStates(max_num_reqs, vocab_size)
self.penalties_state = PenaltiesState(max_num_reqs, vocab_size, device)
self.logit_bias_state = LogitBiasState(max_num_reqs, device)
self.bad_words_state = BadWordsState(all_token_ids, prompt_len, total_len)
self.num_speculative_tokens = num_speculative_tokens
def add_request(
@@ -41,18 +46,20 @@ class Sampler:
self.sampling_states.add_request(req_idx, sampling_params)
self.penalties_state.add_request(req_idx, sampling_params)
self.logit_bias_state.add_request(req_idx, prompt_len, sampling_params)
self.bad_words_state.add_request(req_idx, sampling_params)
def apply_staged_writes(
self,
prefill_token_ids: torch.Tensor,
all_token_ids: torch.Tensor,
prefill_lens: np.ndarray,
prompt_lens: np.ndarray,
) -> None:
self.sampling_states.apply_staged_writes()
self.penalties_state.apply_staged_writes(
prefill_token_ids, prefill_lens, prompt_lens
all_token_ids, prefill_lens, prompt_lens
)
self.logit_bias_state.apply_staged_writes()
self.bad_words_state.apply_staged_writes()
def __call__(
self,
@@ -124,6 +131,15 @@ class Sampler:
self.num_speculative_tokens,
)
# Apply bad words masking in place.
self.bad_words_state.apply_bad_words(
logits,
idx_mapping,
idx_mapping_np,
input_ids,
expanded_local_pos,
)
# Apply temperature in place.
self.sampling_states.apply_temperature(logits, idx_mapping, idx_mapping_np)


@@ -27,17 +27,30 @@ class RequestState:
self.index_to_req_id: dict[int, str] = {}
self.free_indices = list(range(max_num_reqs))
self.prompt_len = np.zeros(self.max_num_reqs, dtype=np.int32)
# NOTE(woosuk): This tensor can be extremely large (e.g., several GBs)
# depending on the configured max_num_reqs and max_model_len.
# To save GPU memory, we use UVA instead of GPU for this tensor.
self.prefill_token_ids = StagedWriteTensor(
self.all_token_ids = StagedWriteTensor(
(self.max_num_reqs, self.max_model_len),
dtype=torch.int32,
device=device,
uva_instead_of_gpu=True,
)
# NOTE(woosuk): Distinguish clearly between prompt_len and prefill_len:
# - prompt_len: Number of tokens in the user-provided prompt.
# - prefill_len: Number of tokens passed into the model runner.
# This can include the prompt and additional partial output tokens,
# so prefill_len >= prompt_len.
# Usually, prefill_len equals prompt_len, but in cases such as resumption after
# preemption, prefill_len may be greater. Differentiating between these values
# is crucial, as certain features such as prompt logprobs or frequency penalties
# must treat prompt and output tokens separately.
self.prompt_len = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
self.prefill_len = UvaBackedTensor(self.max_num_reqs, dtype=torch.int32)
# total_len = prompt_len + output_len. It grows as the request progresses.
self.total_len = StagedWriteTensor(
self.max_num_reqs, dtype=torch.int32, device=device
)
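# Illustrative example (not part of the diff): a request with a 10-token prompt
# generates 3 tokens and is then preempted. When it is rescheduled, all 13 tokens
# are replayed through prefill, so prompt_len == 10, prefill_len == 13, and
# total_len starts at 13 and grows by one for every newly sampled token.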
# Number of computed tokens.
self.num_computed_prefill_tokens = np.zeros(self.max_num_reqs, dtype=np.int32)
@@ -72,7 +85,7 @@ class RequestState:
self,
req_id: str,
prompt_len: int,
prefill_token_ids: list[int],
all_token_ids: list[int],
num_computed_tokens: int,
) -> None:
assert len(self.free_indices) > 0, "No free indices"
@@ -80,19 +93,22 @@ class RequestState:
self.req_id_to_index[req_id] = req_idx
self.index_to_req_id[req_idx] = req_id
self.prompt_len[req_idx] = prompt_len
prefill_len = len(prefill_token_ids)
self.prompt_len.np[req_idx] = prompt_len
prefill_len = len(all_token_ids)
assert prefill_len >= prompt_len, (
f"prefill_len {prefill_len} < prompt_len {prompt_len}"
)
self.prefill_len.np[req_idx] = prefill_len
self.prefill_token_ids.stage_write(req_idx, 0, prefill_token_ids)
self.total_len.stage_write_elem(req_idx, prefill_len)
self.all_token_ids.stage_write(req_idx, 0, all_token_ids)
self.num_computed_prefill_tokens[req_idx] = num_computed_tokens
self.num_computed_tokens.stage_write_elem(req_idx, num_computed_tokens)
def apply_staged_writes(self) -> None:
self.prompt_len.copy_to_uva()
self.prefill_len.copy_to_uva()
self.prefill_token_ids.apply_write()
self.total_len.apply_write()
self.all_token_ids.apply_write()
self.num_computed_tokens.apply_write()
def remove_request(self, req_id: str) -> None: