[V1] Support bad_words in sampler (#13376)

Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
Author: 22quinn
Date: 2025-03-08 14:50:26 -08:00
Committed by: GitHub
Parent: 9513290032
Commit: eb8b5eb183
13 changed files with 266 additions and 28 deletions
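
For context, a minimal usage sketch of the feature this commit enables (model, prompt, and word list are placeholders; assuming the standard vllm.LLM entry point):

    from vllm import LLM, SamplingParams

    # Hypothetical example: forbid the listed strings in generated text.
    # Each bad word is tokenized into one or more token-id sequences, and
    # the V1 sampler blocks the final token of a sequence whenever the
    # previously generated tokens match its prefix.
    llm = LLM(model="facebook/opt-125m")  # placeholder model
    params = SamplingParams(temperature=0.8, bad_words=["Hello", "bad word"])
    outputs = llm.generate(["The capital of France is"], params)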

vllm/v1/worker/gpu_input_batch.py

@@ -10,6 +10,7 @@ import torch
 from vllm.lora.request import LoRARequest
 from vllm.multimodal import MultiModalKwargs
 from vllm.sampling_params import SamplingParams, SamplingType
+from vllm.utils import swap_dict_values
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.utils import copy_slice
 from vllm.v1.worker.block_table import BlockTable
@@ -204,6 +205,9 @@ class InputBatch:
         self.allowed_token_ids_mask: Optional[torch.Tensor] = None
         self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None
 
+        # req_index -> bad_words_token_ids
+        self.bad_words_token_ids: dict[int, list[list[int]]] = {}
+
         self.req_output_token_ids: list[Optional[list[int]]] = []
 
         # This is updated each time the batch constituents change.
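
The new mapping keeps, per batch slot, the token-id sequences of that request's bad words. A hypothetical illustration (the ids are made up):

    # A request at batch index 3 banning two phrases; a phrase may tokenize
    # into several ids, so each entry is a list of token-id sequences.
    self.bad_words_token_ids[3] = [
        [7592],        # e.g. "Hello" as a single token (made-up id)
        [2919, 1734],  # e.g. "bad word" as two tokens (made-up ids)
    ]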
@@ -320,6 +324,9 @@ class InputBatch:
             self.allowed_token_ids_mask_cpu_tensor[req_index][
                 sampling_params.allowed_token_ids] = False
 
+        self.bad_words_token_ids[
+            req_index] = sampling_params.bad_words_token_ids
+
         # Add request lora ID
         if request.lora_request:
             lora_id = request.lora_request.lora_int_id
@@ -369,6 +376,7 @@ class InputBatch:
         if self.allowed_token_ids_mask_cpu_tensor is not None:
             # False means we don't fill with -inf.
             self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
+        self.bad_words_token_ids.pop(req_index, None)
         return req_index
 
     def swap_states(self, i1: int, i2: int) -> None:
@@ -413,27 +421,9 @@ class InputBatch:
         self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
         self.token_ids_cpu[i2, ...] = tmp
 
-        g1 = self.generators.get(i1)
-        g2 = self.generators.get(i2)
-        if g1 is not None:
-            self.generators[i2] = g1
-        else:
-            self.generators.pop(i2, None)
-        if g2 is not None:
-            self.generators[i1] = g2
-        else:
-            self.generators.pop(i1, None)
-
-        t1 = self.min_tokens.get(i1)
-        t2 = self.min_tokens.get(i2)
-        if t1 is not None:
-            self.min_tokens[i2] = t1
-        else:
-            self.min_tokens.pop(i2, None)
-        if t2 is not None:
-            self.min_tokens[i1] = t2
-        else:
-            self.min_tokens.pop(i1, None)
+        swap_dict_values(self.generators, i1, i2)
+        swap_dict_values(self.min_tokens, i1, i2)
+        swap_dict_values(self.bad_words_token_ids, i1, i2)
 
         self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\
             self.request_lora_mapping[i2], self.request_lora_mapping[i1]
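
The three swap_dict_values calls replace the duplicated get/pop blocks deleted above. A sketch of what vllm.utils.swap_dict_values presumably does, generalizing that removed inline code:

    from typing import TypeVar

    _K = TypeVar("_K")
    _V = TypeVar("_V")

    def swap_dict_values(obj: dict[_K, _V], key1: _K, key2: _K) -> None:
        """Swap the values of two keys; a key absent before stays absent."""
        v1 = obj.get(key1)
        v2 = obj.get(key2)
        if v1 is not None:
            obj[key2] = v1
        else:
            obj.pop(key2, None)
        if v2 is not None:
            obj[key1] = v2
        else:
            obj.pop(key1, None)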
@@ -518,6 +508,10 @@ class InputBatch:
                     empty_index] = self.allowed_token_ids_mask_cpu_tensor[
                         last_req_index]
 
+            bad_words_token_ids = self.bad_words_token_ids.pop(
+                last_req_index, None)
+            if bad_words_token_ids is not None:
+                self.bad_words_token_ids[empty_index] = bad_words_token_ids
+
             # Decrement last_req_index since it is now empty.
             last_req_index -= 1
@@ -585,6 +579,7 @@ class InputBatch:
             no_penalties=self.no_penalties,
             logit_bias=self.logit_bias[:num_reqs],
             allowed_token_ids_mask=allowed_token_ids_mask,
+            bad_words_token_ids=self.bad_words_token_ids,
         )
 
     def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
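
With bad_words_token_ids threaded through SamplingMetadata, the sampler can mask offending tokens before sampling. A sketch of that masking step (names and signature are illustrative, not the actual vLLM API):

    import torch

    def apply_bad_words(
        logits: torch.Tensor,  # shape [num_reqs, vocab_size]
        bad_words_token_ids: dict[int, list[list[int]]],
        output_token_ids: list[list[int]],  # tokens generated so far
    ) -> None:
        for req_idx, bad_words in bad_words_token_ids.items():
            generated = output_token_ids[req_idx]
            for seq in bad_words:
                prefix = seq[:-1]
                # Ban the sequence's final token when everything generated
                # so far ends with its prefix (single-token bad words have
                # an empty prefix and are always banned).
                if len(prefix) <= len(generated) and (
                        not prefix or generated[-len(prefix):] == prefix):
                    logits[req_idx, seq[-1]] = -float("inf")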

vllm/v1/worker/gpu_model_runner.py

@@ -1268,6 +1268,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             min_tokens={},
             logit_bias=[None for _ in range(num_reqs)],
             allowed_token_ids_mask=None,
+            bad_words_token_ids={},
         )
         sampler_output = self.model.sample(logits=logits,
                                            sampling_metadata=dummy_metadata)
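
The dummy metadata built for the profiling run passes an empty dict, since no real requests (and thus no bad words) exist at that point. For completeness, a sketch of how bad_words strings might be expanded into the token-id sequences stored in the batch (hypothetical helper; the space-prefixed variant is included because subword tokenizers often encode " word" and "word" differently):

    def bad_words_to_token_ids(tokenizer,
                               bad_words: list[str]) -> list[list[int]]:
        sequences: list[list[int]] = []
        for word in bad_words:
            # Ban both the bare word and its space-prefixed form.
            for text in (word, " " + word):
                ids = tokenizer.encode(text, add_special_tokens=False)
                if ids:
                    sequences.append(ids)
        return sequences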