[Bugfix][Frontend] Guard against bad token ids (#9634)
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
This commit is contained in:
@@ -10,7 +10,7 @@ from typing import Sequence as GenericSequence
|
||||
from typing import Set, Type, Union, cast, overload
|
||||
|
||||
import torch
|
||||
from typing_extensions import TypeVar
|
||||
from typing_extensions import TypeIs, TypeVar
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig,
|
||||
@@ -32,7 +32,8 @@ from vllm.executor.executor_base import ExecutorBase
|
||||
from vllm.executor.gpu_executor import GPUExecutor
|
||||
from vllm.executor.ray_utils import initialize_ray_cluster
|
||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs,
|
||||
EncoderDecoderInputs, InputRegistry, PromptType)
|
||||
EncoderDecoderInputs, InputRegistry, PromptType,
|
||||
TokensPrompt)
|
||||
from vllm.inputs.preprocess import InputPreprocessor
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logits_process import get_bad_words_logits_processors
|
||||
@@ -667,7 +668,7 @@ class LLMEngine:
|
||||
)
|
||||
return None
|
||||
|
||||
self._validate_model_inputs(processed_inputs)
|
||||
self._validate_model_inputs(processed_inputs, lora_request)
|
||||
# Create the sequences.
|
||||
block_size = self.cache_config.block_size
|
||||
seq_id = next(self.seq_counter)
|
||||
@@ -829,6 +830,11 @@ class LLMEngine:
|
||||
if arrival_time is None:
|
||||
arrival_time = time.time()
|
||||
|
||||
if self.tokenizer is not None:
|
||||
self._validate_token_prompt(
|
||||
prompt,
|
||||
tokenizer=self.get_tokenizer(lora_request=lora_request))
|
||||
|
||||
preprocessed_inputs = self.input_preprocessor.preprocess(
|
||||
prompt,
|
||||
request_id=request_id,
|
||||
@@ -855,6 +861,31 @@ class LLMEngine:
|
||||
priority=priority,
|
||||
)
|
||||
|
||||
def _validate_token_prompt(self, prompt: PromptType,
                           tokenizer: AnyTokenizer):
    """Reject token prompts containing ids outside the tokenizer's vocab.

    Raises:
        ValueError: if any prompt token id is above ``tokenizer.max_token_id``
            or below zero.
    """
    # Guard against out-of-vocab tokens.
    # For some tokenizers, tokenizer.decode will happily return empty text
    # for token ids that are out of vocab, and we don't detect token ids
    # that are greater than the max token id before running the model.
    # However, these token ids will later crash a cuda kernel at runtime
    # with an index out of bounds error. This will crash the entire engine.
    # This needs to happen before multimodal input pre-processing, which
    # may add dummy <image> tokens that aren't part of the tokenizer's
    # vocabulary.
    if self._is_token_prompt(prompt):
        prompt_ids = prompt["prompt_token_ids"]
        if len(prompt_ids) == 0:
            # Empty prompt check is handled later
            return
        max_input_id = max(prompt_ids)
        if max_input_id > tokenizer.max_token_id:
            raise ValueError(
                f"Token id {max_input_id} is out of vocabulary")
        # Negative ids are just as invalid: they would silently index
        # the embedding table from the end (or fault), so reject them
        # here as well.
        min_input_id = min(prompt_ids)
        if min_input_id < 0:
            raise ValueError(
                f"Token id {min_input_id} is out of vocabulary")
|
||||
|
||||
@staticmethod
|
||||
def _is_token_prompt(prompt: PromptType) -> TypeIs[TokensPrompt]:
|
||||
return isinstance(prompt, dict) and "prompt_token_ids" in prompt
|
||||
|
||||
def _create_sequence_group_with_sampling(
|
||||
self,
|
||||
request_id: str,
|
||||
@@ -1942,7 +1973,8 @@ class LLMEngine:
|
||||
return self.input_preprocessor.is_encoder_decoder_model()
|
||||
|
||||
def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs,
|
||||
EncoderDecoderInputs]):
|
||||
EncoderDecoderInputs],
|
||||
lora_request: Optional[LoRARequest]):
|
||||
if self.model_config.is_multimodal_model:
|
||||
# For encoder-decoder multimodal models, the max_prompt_len
|
||||
# restricts the decoder prompt length
|
||||
|
||||
Reference in New Issue
Block a user