[V1] Support bad_words in sampler (#13376)
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -10,6 +10,7 @@ import torch
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal import MultiModalKwargs
|
||||
from vllm.sampling_params import SamplingParams, SamplingType
|
||||
from vllm.utils import swap_dict_values
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.utils import copy_slice
|
||||
from vllm.v1.worker.block_table import BlockTable
|
||||
@@ -204,6 +205,9 @@ class InputBatch:
|
||||
self.allowed_token_ids_mask: Optional[torch.Tensor] = None
|
||||
self.allowed_token_ids_mask_cpu_tensor: Optional[torch.Tensor] = None
|
||||
|
||||
# req_index -> bad_words_token_ids
|
||||
self.bad_words_token_ids: dict[int, list[list[int]]] = {}
|
||||
|
||||
self.req_output_token_ids: list[Optional[list[int]]] = []
|
||||
|
||||
# This is updated each time the batch constituents change.
|
||||
@@ -320,6 +324,9 @@ class InputBatch:
|
||||
self.allowed_token_ids_mask_cpu_tensor[req_index][
|
||||
sampling_params.allowed_token_ids] = False
|
||||
|
||||
self.bad_words_token_ids[
|
||||
req_index] = sampling_params.bad_words_token_ids
|
||||
|
||||
# Add request lora ID
|
||||
if request.lora_request:
|
||||
lora_id = request.lora_request.lora_int_id
|
||||
@@ -369,6 +376,7 @@ class InputBatch:
|
||||
if self.allowed_token_ids_mask_cpu_tensor is not None:
|
||||
# False means we don't fill with -inf.
|
||||
self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
|
||||
self.bad_words_token_ids.pop(req_index, None)
|
||||
return req_index
|
||||
|
||||
def swap_states(self, i1: int, i2: int) -> None:
|
||||
@@ -413,27 +421,9 @@ class InputBatch:
|
||||
self.token_ids_cpu[i1, ...] = self.token_ids_cpu[i2, ...]
|
||||
self.token_ids_cpu[i2, ...] = tmp
|
||||
|
||||
g1 = self.generators.get(i1)
|
||||
g2 = self.generators.get(i2)
|
||||
if g1 is not None:
|
||||
self.generators[i2] = g1
|
||||
else:
|
||||
self.generators.pop(i2, None)
|
||||
if g2 is not None:
|
||||
self.generators[i1] = g2
|
||||
else:
|
||||
self.generators.pop(i1, None)
|
||||
|
||||
t1 = self.min_tokens.get(i1)
|
||||
t2 = self.min_tokens.get(i2)
|
||||
if t1 is not None:
|
||||
self.min_tokens[i2] = t1
|
||||
else:
|
||||
self.min_tokens.pop(i2, None)
|
||||
if t2 is not None:
|
||||
self.min_tokens[i1] = t2
|
||||
else:
|
||||
self.min_tokens.pop(i1, None)
|
||||
swap_dict_values(self.generators, i1, i2)
|
||||
swap_dict_values(self.min_tokens, i1, i2)
|
||||
swap_dict_values(self.bad_words_token_ids, i1, i2)
|
||||
|
||||
self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\
|
||||
self.request_lora_mapping[i2], self.request_lora_mapping[i1]
|
||||
@@ -518,6 +508,10 @@ class InputBatch:
|
||||
empty_index] = self.allowed_token_ids_mask_cpu_tensor[
|
||||
last_req_index]
|
||||
|
||||
bad_words_token_ids = self.bad_words_token_ids.pop(
|
||||
last_req_index, None)
|
||||
if bad_words_token_ids is not None:
|
||||
self.bad_words_token_ids[empty_index] = bad_words_token_ids
|
||||
# Decrement last_req_index since it is now empty.
|
||||
last_req_index -= 1
|
||||
|
||||
@@ -585,6 +579,7 @@ class InputBatch:
|
||||
no_penalties=self.no_penalties,
|
||||
logit_bias=self.logit_bias[:num_reqs],
|
||||
allowed_token_ids_mask=allowed_token_ids_mask,
|
||||
bad_words_token_ids=self.bad_words_token_ids,
|
||||
)
|
||||
|
||||
def _make_prompt_token_ids_tensor(self) -> torch.Tensor:
|
||||
|
||||
@@ -1268,6 +1268,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
|
||||
min_tokens={},
|
||||
logit_bias=[None for _ in range(num_reqs)],
|
||||
allowed_token_ids_mask=None,
|
||||
bad_words_token_ids={},
|
||||
)
|
||||
sampler_output = self.model.sample(logits=logits,
|
||||
sampling_metadata=dummy_metadata)
|
||||
|
||||
Reference in New Issue
Block a user