# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
from dataclasses import dataclass, field
|
|
from typing import TYPE_CHECKING, Any, Literal, TypeVar
|
|
|
|
from vllm.exceptions import VLLMValidationError
|
|
from vllm.inputs import EmbedsPrompt, TextPrompt, TokensPrompt
|
|
from vllm.logger import init_logger
|
|
from vllm.multimodal.media.connector import merge_media_io_kwargs
|
|
from vllm.tokenizers import TokenizerLike
|
|
from vllm.utils.import_utils import LazyLoader
|
|
|
|
# `torch` and `ChatTemplateContentFormatOption` are imported for real only
# when a static type checker is running; at runtime the `else` branch binds
# stand-ins under the same names.
if TYPE_CHECKING:
    import torch

    from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
else:
    # NOTE(review): presumably `LazyLoader` defers the actual `torch` import
    # until the module is first used -- confirm in `vllm.utils.import_utils`.
    torch = LazyLoader("torch", globals(), "torch")

    # Runtime placeholder so that the name is defined outside type checking.
    ChatTemplateContentFormatOption = object

# Module-level logger.
logger = init_logger(__name__)

# Type variable constrained to either a list of token IDs or a `torch.Tensor`;
# the tensor constraint is a forward-reference string.
_S = TypeVar("_S", list[int], "torch.Tensor")
|
|
|
|
|
|
def merge_kwargs(
|
|
defaults: dict[str, Any] | None,
|
|
overrides: dict[str, Any] | None,
|
|
/,
|
|
*,
|
|
unset_values: tuple[object, ...] = (None, "auto"),
|
|
) -> dict[str, Any]:
|
|
if defaults is None:
|
|
defaults = {}
|
|
if overrides is None:
|
|
overrides = {}
|
|
|
|
return defaults | {k: v for k, v in overrides.items() if v not in unset_values}
|
|
|
|
|
|
def recursively_merge_kwargs(
|
|
defaults: dict[str, Any] | None,
|
|
overrides: dict[str, Any] | None,
|
|
/,
|
|
*,
|
|
unset_values: tuple[object, ...] = (None, "auto"),
|
|
) -> dict[str, Any]:
|
|
if defaults is None:
|
|
defaults = {}
|
|
if overrides is None:
|
|
overrides = {}
|
|
|
|
merged = dict(defaults)
|
|
|
|
for k, v in overrides.items():
|
|
if v in unset_values:
|
|
continue
|
|
|
|
if k in merged and isinstance(merged[k], dict) and isinstance(v, dict):
|
|
merged[k] = recursively_merge_kwargs(
|
|
merged[k], v, unset_values=unset_values
|
|
)
|
|
else:
|
|
merged[k] = v
|
|
|
|
return merged
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class ChatParams:
|
|
"""Configuration to control how to parse chat messages."""
|
|
|
|
chat_template: str | None = None
|
|
"""The chat template to apply."""
|
|
|
|
chat_template_content_format: "ChatTemplateContentFormatOption" = "auto"
|
|
"""The format of the chat template."""
|
|
|
|
chat_template_kwargs: dict[str, Any] = field(default_factory=dict)
|
|
"""The kwargs to pass to the chat template."""
|
|
|
|
media_io_kwargs: dict[str, dict[str, Any]] | None = None
|
|
"""Per-modality kwargs for media I/O (loading/decoding images, videos, etc.)."""
|
|
|
|
mm_processor_kwargs: dict[str, Any] | None = None
|
|
"""The kwargs to pass to the multi-modal processor."""
|
|
|
|
def with_defaults(
|
|
self,
|
|
default_chat_template_kwargs: dict[str, Any] | None = None,
|
|
default_media_io_kwargs: dict[str, dict[str, Any]] | None = None,
|
|
default_mm_processor_kwargs: dict[str, Any] | None = None,
|
|
):
|
|
if (
|
|
not default_chat_template_kwargs
|
|
and not default_media_io_kwargs
|
|
and not default_mm_processor_kwargs
|
|
):
|
|
return self
|
|
|
|
return ChatParams(
|
|
chat_template=self.chat_template,
|
|
chat_template_content_format=self.chat_template_content_format,
|
|
chat_template_kwargs=merge_kwargs(
|
|
default_chat_template_kwargs,
|
|
self.chat_template_kwargs,
|
|
),
|
|
media_io_kwargs=merge_media_io_kwargs(
|
|
default_media_io_kwargs,
|
|
self.media_io_kwargs,
|
|
),
|
|
mm_processor_kwargs=recursively_merge_kwargs(
|
|
default_mm_processor_kwargs,
|
|
self.mm_processor_kwargs,
|
|
),
|
|
)
|
|
|
|
def get_apply_chat_template_kwargs(self) -> dict[str, Any]:
|
|
"""The arguments to pass to `tokenizer.apply_chat_template`."""
|
|
return merge_kwargs(
|
|
self.chat_template_kwargs,
|
|
dict(chat_template=self.chat_template, return_dict=False),
|
|
)
|
|
|
|
|
|
@dataclass(frozen=True)
class TokenizeParams:
    """Configuration to control how prompts are tokenized."""

    max_total_tokens: int | None
    """
    Maximum allowed number of input + output tokens.

    Usually, this refers to the model's context length.
    """

    max_output_tokens: int = 0
    """Maximum requested number of output tokens."""

    pad_prompt_tokens: int | None = None
    """
    Number of tokens to pad to:
    - `None` means no padding.
    - `-1` maps to `max_input_tokens`.
    """

    truncate_prompt_tokens: int | None = None
    """
    Number of tokens to keep:
    - `None` means no truncation.
    - `-1` maps to `max_input_tokens`.
    """

    truncation_side: Literal["left", "right"] | None = None
    """
    Which side to truncate from when ``truncate_prompt_tokens`` is active:
    - ``"right"`` keeps the first N tokens (truncate from the end).
    - ``"left"`` keeps the last N tokens (truncate from the start).
    - ``None`` falls back to the tokenizer default.
    """

    do_lower_case: bool = False
    """Whether to normalize text to lower case before tokenization."""

    add_special_tokens: bool = True
    """Whether to add special tokens."""

    needs_detokenization: bool = False
    """
    Whether the tokenized prompt needs to contain the original text.

    Not to be confused with `SamplingParams.detokenize` which deals
    with the output generated by the model.
    """

    max_total_tokens_param: str = "max_total_tokens"
    """Override this to edit the message for validation errors."""

    max_output_tokens_param: str = "max_output_tokens"
    """Override this to edit the message for validation errors."""

    truncate_prompt_tokens_param: str = "truncate_prompt_tokens"
    """Override this to edit the message for validation errors."""

    @property
    def max_input_tokens(self) -> int | None:
        """
        Maximum allowed number of input tokens.

        This is `max_total_tokens - max_output_tokens`, or `None` when
        `max_total_tokens` is not set.
        """
        if self.max_total_tokens is None:
            return None

        return self.max_total_tokens - self.max_output_tokens

    def __post_init__(self) -> None:
        """
        Validate the relationship between the configured token limits.

        Raises:
            VLLMValidationError: If `max_output_tokens` exceeds
                `max_total_tokens`, or if `truncate_prompt_tokens` exceeds
                `max_input_tokens`.
        """
        max_total_tokens = self.max_total_tokens
        max_output_tokens = self.max_output_tokens
        max_input_tokens = self.max_input_tokens
        truncate_prompt_tokens = self.truncate_prompt_tokens

        if (
            max_output_tokens is not None
            and max_total_tokens is not None
            and max_output_tokens > max_total_tokens
        ):
            # Note the trailing space after the first fragment and the plain
            # `{...}` (not `{...=}`) for the limit; otherwise the message
            # renders as "...=10cannot ... max_total_tokens=max_total_tokens=5".
            raise VLLMValidationError(
                f"{self.max_output_tokens_param}={max_output_tokens} "
                f"cannot be greater than "
                f"{self.max_total_tokens_param}={max_total_tokens}. "
                f"Please request fewer output tokens.",
                parameter=self.max_output_tokens_param,
                value=max_output_tokens,
            )

        if (
            max_input_tokens is not None
            and truncate_prompt_tokens is not None
            and truncate_prompt_tokens > max_input_tokens
        ):
            raise VLLMValidationError(
                f"{self.truncate_prompt_tokens_param}={truncate_prompt_tokens} "
                f"cannot be greater than {self.max_total_tokens_param} - "
                f"{self.max_output_tokens_param} = {max_input_tokens}. "
                f"Please request a smaller truncation size.",
                parameter=self.truncate_prompt_tokens_param,
                value=truncate_prompt_tokens,
            )

    def with_kwargs(self, **tokenization_kwargs: Any):
        """
        Return a new `TokenizeParams` updated from tokenizer-style kwargs.

        Recognized keys are `max_length`, `pad_prompt_tokens`,
        `truncate_prompt_tokens`, `do_lower_case`, `add_special_tokens`,
        `needs_detokenization`, plus the HF-style `padding` and `truncation`
        options:

        - `padding="max_length"` pads to `max_length`;
          `padding=False` / `"do_not_pad"` disables padding.
        - `truncation=True` / `"longest_first"` truncates to `max_length`;
          `truncation=False` / `"do_not_truncate"` disables truncation.

        Any other key (including unsupported `padding` / `truncation`
        values) is reported via a warning and ignored.
        """
        max_length = tokenization_kwargs.pop("max_length", self.max_input_tokens)
        pad_prompt_tokens = tokenization_kwargs.pop(
            "pad_prompt_tokens", self.pad_prompt_tokens
        )
        truncate_prompt_tokens = tokenization_kwargs.pop(
            "truncate_prompt_tokens", self.truncate_prompt_tokens
        )
        do_lower_case = tokenization_kwargs.pop("do_lower_case", self.do_lower_case)
        add_special_tokens = tokenization_kwargs.pop(
            "add_special_tokens", self.add_special_tokens
        )
        needs_detokenization = tokenization_kwargs.pop(
            "needs_detokenization", self.needs_detokenization
        )

        # https://huggingface.co/docs/transformers/en/pad_truncation
        # Test against `None` instead of truthiness: an explicit `False` is
        # falsy, so a truthiness check would skip the "disable" branch and
        # silently keep the inherited setting.
        padding = tokenization_kwargs.pop("padding", None)
        if padding is not None:
            if padding == "max_length":
                pad_prompt_tokens = max_length
            elif padding in (False, "do_not_pad"):
                pad_prompt_tokens = None
            else:
                # To emit the below warning
                tokenization_kwargs["padding"] = padding

        truncation = tokenization_kwargs.pop("truncation", None)
        if truncation is not None:
            if truncation in (True, "longest_first"):
                truncate_prompt_tokens = max_length
            elif truncation in (False, "do_not_truncate"):
                truncate_prompt_tokens = None
            else:
                # To emit the below warning
                tokenization_kwargs["truncation"] = truncation

        if tokenization_kwargs:
            logger.warning(
                "The following tokenization arguments are not supported "
                "by vLLM Renderer and will be ignored: %s",
                tokenization_kwargs,
            )

        max_total_tokens = self.max_total_tokens

        return TokenizeParams(
            max_total_tokens=max_total_tokens,
            # Chosen so that the new instance's `max_input_tokens` equals
            # `max_length`.
            max_output_tokens=(
                0
                if max_total_tokens is None or max_length is None
                else max_total_tokens - max_length
            ),
            pad_prompt_tokens=pad_prompt_tokens,
            truncate_prompt_tokens=truncate_prompt_tokens,
            truncation_side=self.truncation_side,
            do_lower_case=do_lower_case,
            add_special_tokens=add_special_tokens,
            needs_detokenization=needs_detokenization,
        )

    def get_encode_kwargs(self) -> dict[str, Any]:
        """The arguments to pass to `tokenizer.encode`."""
        max_length = self.truncate_prompt_tokens
        if max_length is not None and max_length < 0:
            max_length = self.max_input_tokens
        elif max_length is None and self.max_input_tokens is not None:
            # This prevents tokenization from taking up more resources than necessary
            # while still failing `self._token_len_check` as expected by users
            max_length = self.max_input_tokens + 1

        # Left-side truncation requires the full token sequence so we can
        # slice from the end in _token_truncation. Disable HF-level
        # truncation (which would incorrectly truncate from the right for
        # pooling models) and let _token_truncation handle it.
        if self.truncation_side == "left":
            return dict(
                truncation=False,
                add_special_tokens=self.add_special_tokens,
            )

        return dict(
            truncation=max_length is not None,
            max_length=max_length,
            add_special_tokens=self.add_special_tokens,
        )

    def _text_len_check(self, tokenizer: TokenizerLike | None, text: str) -> str:
        """
        Apply length checks to prompt text if necessary.

        The check only runs when an input limit exists, no truncation is
        requested and a tokenizer is available.

        Raises:
            VLLMValidationError: If the text has more characters than
                `max_input_tokens * tokenizer.max_chars_per_token`.
        """
        max_input_tokens = self.max_input_tokens
        if max_input_tokens is None:
            return text

        if self.truncate_prompt_tokens is None and tokenizer is not None:
            max_input_chars = max_input_tokens * tokenizer.max_chars_per_token

            if len(text) > max_input_chars:
                # To save resources, fail the request outright without even
                # attempting tokenization
                raise VLLMValidationError(
                    f"This model's maximum context length is "
                    f"{self.max_total_tokens} tokens. However, you requested "
                    f"{self.max_output_tokens} output tokens and your prompt "
                    f"contains {len(text)} characters (more than "
                    f"{max_input_chars} characters, which is the upper bound "
                    f"for {max_input_tokens} input tokens). "
                    f"Please reduce the length of the input prompt or the "
                    f"number of requested output tokens.",
                    parameter="input_text",
                    value=len(text),
                )

        return text

    def _text_lowercase(self, tokenizer: TokenizerLike | None, text: str) -> str:
        """Apply lowercase to prompt text if necessary."""
        return text.lower() if self.do_lower_case else text

    def _validate_text(self, tokenizer: TokenizerLike | None, text: str) -> str:
        """Apply all validators to prompt text (length check, then lowercasing)."""
        for validator in (
            self._text_len_check,
            self._text_lowercase,
        ):
            text = validator(tokenizer, text)

        return text

    def apply_pre_tokenization(
        self,
        tokenizer: TokenizerLike | None,
        prompt: TextPrompt,
    ) -> TextPrompt:
        """
        Ensure that the prompt meets the requirements set out by this config.
        If that is not possible, raise a `VLLMValidationError`.

        This method is run before tokenization occurs. The prompt is updated
        in place and also returned.
        """
        prompt["prompt"] = self._validate_text(tokenizer, prompt["prompt"])

        return prompt

    def _token_padding(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
        """
        Apply padding to prompt tokens if necessary.

        Raises:
            ValueError: If padding is required but no tokenizer is available,
                or `tokens` is not a list.
        """
        pad_length = self.pad_prompt_tokens
        if pad_length is not None and pad_length < 0:
            pad_length = self.max_input_tokens

        if pad_length is None or pad_length <= len(tokens):
            return tokens

        if tokenizer is None:
            raise ValueError("Cannot pad tokens when `skip_tokenizer_init=True`")
        if not isinstance(tokens, list):
            raise ValueError("Cannot pad tokens for embedding inputs")

        return tokens + [tokenizer.pad_token_id] * (pad_length - len(tokens))

    def _token_truncation(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
        """
        Apply truncation to prompt tokens if necessary.

        The side is taken from `truncation_side`, falling back to the
        tokenizer's `truncation_side` when a tokenizer is available.
        Anything other than `"left"` keeps the leading tokens.
        """
        max_length = self.truncate_prompt_tokens
        if max_length is not None and max_length < 0:
            max_length = self.max_input_tokens

        if max_length is None or max_length >= len(tokens):
            return tokens
        if max_length == 0:
            # `tokens[-0:]` would return the whole sequence, so handle the
            # empty result before the side-specific slicing below.
            return tokens[:0]

        side = self.truncation_side or (
            tokenizer.truncation_side if tokenizer is not None else None
        )
        if side == "left":
            return tokens[-max_length:]

        return tokens[:max_length]

    def _token_len_check(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
        """
        Apply length checks to prompt tokens if necessary.

        Raises:
            VLLMValidationError: If the number of tokens exceeds
                `max_input_tokens`.
        """
        max_input_tokens = self.max_input_tokens
        if max_input_tokens is None:
            return tokens

        if len(tokens) > max_input_tokens:
            token_count = len(tokens)
            # The tokenizer may have truncated the prompt to
            # max_input_tokens + 1 (see get_encode_kwargs), so the
            # actual prompt length could be larger.
            qualifier = "at least " if token_count == max_input_tokens + 1 else ""
            total = token_count + self.max_output_tokens
            raise VLLMValidationError(
                f"This model's maximum context length is "
                f"{self.max_total_tokens} tokens. However, you requested "
                f"{self.max_output_tokens} output tokens and your prompt "
                f"contains {qualifier}{token_count} input tokens, "
                f"for a total of {qualifier}{total} tokens. "
                f"Please reduce the length of the input prompt or the "
                f"number of requested output tokens.",
                parameter="input_tokens",
                value=token_count,
            )

        return tokens

    def _validate_tokens(self, tokenizer: TokenizerLike | None, tokens: _S) -> _S:
        """
        Apply all validators to a token sequence.

        The order is padding, then truncation, then the length check.
        """
        for validator in (
            self._token_padding,
            self._token_truncation,
            self._token_len_check,
        ):
            tokens = validator(tokenizer, tokens)

        return tokens

    def apply_post_tokenization(
        self,
        tokenizer: TokenizerLike | None,
        prompt: TokensPrompt | EmbedsPrompt,
    ) -> TokensPrompt | EmbedsPrompt:
        """
        Ensure that the prompt meets the requirements set out by this config.
        If that is not possible, raise a `VLLMValidationError`.

        This method is run after tokenization occurs. The prompt is updated
        in place and also returned.
        """
        if "prompt_token_ids" in prompt:
            prompt["prompt_token_ids"] = self._validate_tokens(  # type: ignore[typeddict-unknown-key]
                tokenizer,
                prompt["prompt_token_ids"],  # type: ignore[typeddict-item]
            )
        if "prompt_embeds" in prompt:
            prompt["prompt_embeds"] = self._validate_tokens(  # type: ignore[typeddict-unknown-key]
                tokenizer,
                prompt["prompt_embeds"],  # type: ignore[typeddict-item]
            )

        return prompt
|