[Bugfix][Frontend] Guard against bad token ids (#9634)

Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
Author:       Joe Runde
Date:         2024-10-29 16:13:20 -05:00
Committed by: GitHub
Parent:       0ad216f575
Commit:       67bdf8e523

7 changed files with 89 additions and 17 deletions
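
The user-visible effect of this change, as a minimal sketch (the model name and printed text are illustrative, not taken from the commit): an out-of-vocab token id now fails fast with a `ValueError` when the request is added, instead of crashing a CUDA kernel mid-run.

```python
# Hedged repro sketch. Before this fix, the oversized token id below
# slipped past preprocessing and later hit an index-out-of-bounds CUDA
# error, taking down the whole engine; now the request is rejected
# up front.
from vllm import LLM
from vllm.inputs import TokensPrompt

llm = LLM(model="facebook/opt-125m")  # any small model works

bad = TokensPrompt(prompt_token_ids=[0, 1, 2, 10_000_000])
try:
    llm.generate(bad)
except ValueError as err:
    print(err)  # e.g. "Token id 10000000 is out of vocabulary"
```

Failing at request-add time keeps the error scoped to the offending request rather than destroying every in-flight batch.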

vllm/engine/llm_engine.py

@@ -10,7 +10,7 @@ from typing import Sequence as GenericSequence
 from typing import Set, Type, Union, cast, overload
 
 import torch
-from typing_extensions import TypeVar
+from typing_extensions import TypeIs, TypeVar
 
 import vllm.envs as envs
 from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
@@ -32,7 +32,8 @@ from vllm.executor.executor_base import ExecutorBase
 from vllm.executor.gpu_executor import GPUExecutor
 from vllm.executor.ray_utils import initialize_ray_cluster
 from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs,
-                         EncoderDecoderInputs, InputRegistry, PromptType)
+                         EncoderDecoderInputs, InputRegistry, PromptType,
+                         TokensPrompt)
 from vllm.inputs.preprocess import InputPreprocessor
 from vllm.logger import init_logger
 from vllm.logits_process import get_bad_words_logits_processors
@@ -667,7 +668,7 @@ class LLMEngine:
             )
             return None
 
-        self._validate_model_inputs(processed_inputs)
+        self._validate_model_inputs(processed_inputs, lora_request)
         # Create the sequences.
         block_size = self.cache_config.block_size
         seq_id = next(self.seq_counter)
@@ -829,6 +830,11 @@ class LLMEngine:
         if arrival_time is None:
             arrival_time = time.time()
 
+        if self.tokenizer is not None:
+            self._validate_token_prompt(
+                prompt,
+                tokenizer=self.get_tokenizer(lora_request=lora_request))
+
         preprocessed_inputs = self.input_preprocessor.preprocess(
             prompt,
             request_id=request_id,
@@ -855,6 +861,31 @@ class LLMEngine:
             priority=priority,
         )
 
+    def _validate_token_prompt(self, prompt: PromptType,
+                               tokenizer: AnyTokenizer):
+        # Guard against out-of-vocab tokens.
+        # For some tokenizers, tokenizer.decode will happily return empty text
+        # for token ids that are out of vocab, and we don't detect token ids
+        # that are greater than the max token id before running the model.
+        # However, these token ids will later crash a cuda kernel at runtime
+        # with an index out of bounds error. This will crash the entire engine.
+        # This needs to happen before multimodal input pre-processing, which
+        # may add dummy <image> tokens that aren't part of the tokenizer's
+        # vocabulary.
+        if self._is_token_prompt(prompt):
+            prompt_ids = prompt["prompt_token_ids"]
+            if len(prompt_ids) == 0:
+                # Empty prompt check is handled later
+                return
+            max_input_id = max(prompt_ids)
+            if max_input_id > tokenizer.max_token_id:
+                raise ValueError(
+                    "Token id {} is out of vocabulary".format(max_input_id))
+
+    @staticmethod
+    def _is_token_prompt(prompt: PromptType) -> TypeIs[TokensPrompt]:
+        return isinstance(prompt, dict) and "prompt_token_ids" in prompt
+
     def _create_sequence_group_with_sampling(
         self,
         request_id: str,
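
The new `TypeIs` import from the first hunk powers the `_is_token_prompt` helper: unlike a plain `bool` return type, `TypeIs[TokensPrompt]` tells type checkers to narrow `prompt` to `TokensPrompt` inside the `if` branch, so the `prompt["prompt_token_ids"]` access above type-checks. A self-contained sketch of the same pattern (the types here are simplified stand-ins, not vLLM's actual definitions):

```python
from typing import List, Union
from typing_extensions import TypedDict, TypeIs

class TokensPrompt(TypedDict):
    # simplified stand-in for vllm.inputs.TokensPrompt
    prompt_token_ids: List[int]

PromptType = Union[str, TokensPrompt]

def is_token_prompt(prompt: PromptType) -> TypeIs[TokensPrompt]:
    return isinstance(prompt, dict) and "prompt_token_ids" in prompt

def max_token_id_in(prompt: PromptType) -> int:
    if is_token_prompt(prompt):
        # Type checkers narrow `prompt` to TokensPrompt here, so the
        # TypedDict key access below is verified statically.
        return max(prompt["prompt_token_ids"], default=-1)
    return -1
```

Without `TypeIs` (or the older `TypeGuard`), the helper would return a plain `bool` and checkers could not prove the key access is safe.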
@@ -1942,7 +1973,8 @@ class LLMEngine:
         return self.input_preprocessor.is_encoder_decoder_model()
 
     def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs,
-                                                   EncoderDecoderInputs]):
+                                                   EncoderDecoderInputs],
+                               lora_request: Optional[LoRARequest]):
         if self.model_config.is_multimodal_model:
             # For encoder-decoder multimodal models, the max_prompt_len
             # restricts the decoder prompt length
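
Both validation paths now receive the request's `lora_request`, presumably because token validity is adapter-dependent: a LoRA adapter can ship its own tokenizer with added tokens, so the highest valid token id can vary per request. A hedged sketch of the idea (the helper name is mine, not vLLM's):

```python
from typing import Optional
from vllm.lora.request import LoRARequest

def max_valid_token_id(engine, lora_request: Optional[LoRARequest]) -> int:
    # LLMEngine.get_tokenizer resolves the tokenizer for the request's
    # LoRA adapter, falling back to the base model's tokenizer, and
    # max_token_id is the largest id that tokenizer can produce.
    tokenizer = engine.get_tokenizer(lora_request=lora_request)
    return tokenizer.max_token_id
```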