[Core] Registry for processing model inputs (#5214)
Co-authored-by: ywang96 <ywang@roblox.com>
vllm/inputs/__init__.py (new file, 19 lines)
@@ -0,0 +1,19 @@
from .data import (LLMInputs, ParsedText, ParsedTokens, PromptInputs,
                   PromptStrictInputs, TextPrompt, TextTokensPrompt,
                   TokensPrompt, parse_and_batch_prompt)
from .registry import InputContext, InputRegistry

INPUT_REGISTRY = InputRegistry()
"""
The global :class:`~InputRegistry` which is used by :class:`~vllm.LLMEngine`
to dispatch data processing according to the target model.

See also:
    :ref:`input_processing_pipeline`
"""

__all__ = [
    "ParsedText", "ParsedTokens", "parse_and_batch_prompt", "TextPrompt",
    "TokensPrompt", "TextTokensPrompt", "PromptStrictInputs", "PromptInputs",
    "LLMInputs", "INPUT_REGISTRY", "InputContext", "InputRegistry"
]
vllm/inputs/data.py (new file, 144 lines)
@@ -0,0 +1,144 @@
from typing import (TYPE_CHECKING, List, Literal, Optional, Sequence,
                    TypedDict, Union, cast, overload)

from typing_extensions import NotRequired

if TYPE_CHECKING:
    from vllm.multimodal import MultiModalData


class ParsedText(TypedDict):
    content: str
    is_tokens: Literal[False]


class ParsedTokens(TypedDict):
    content: List[int]
    is_tokens: Literal[True]


# https://github.com/vllm-project/vllm/pull/4028
@overload
def parse_and_batch_prompt(
        prompt: Union[str, List[str]]) -> Sequence[ParsedText]:
    ...


@overload
def parse_and_batch_prompt(
        prompt: Union[List[int], List[List[int]]]) -> Sequence[ParsedTokens]:
    ...


def parse_and_batch_prompt(
    prompt: Union[str, List[str], List[int], List[List[int]]],
) -> Union[Sequence[ParsedText], Sequence[ParsedTokens]]:
    if isinstance(prompt, str):
        # case 1: a string
        return [ParsedText(content=prompt, is_tokens=False)]

    if isinstance(prompt, list):
        if len(prompt) == 0:
            raise ValueError("please provide at least one prompt")

        if isinstance(prompt[0], str):
            # case 2: array of strings
            return [
                ParsedText(content=elem, is_tokens=False)
                for elem in cast(List[str], prompt)
            ]
        if isinstance(prompt[0], int):
            # case 3: array of tokens
            elem = cast(List[int], prompt)
            return [ParsedTokens(content=elem, is_tokens=True)]
        if isinstance(prompt[0], list):
            if len(prompt[0]) == 0:
                raise ValueError("please provide at least one prompt")

            if isinstance(prompt[0][0], int):
                # case 4: array of token arrays
                return [
                    ParsedTokens(content=elem, is_tokens=True)
                    for elem in cast(List[List[int]], prompt)
                ]

    raise ValueError("prompt must be a string, array of strings, "
                     "array of tokens, or array of token arrays")
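To make the four cases concrete, a short sketch (not lines from this diff; the token IDs are placeholders, and at runtime the TypedDicts are plain dicts):

from vllm.inputs import parse_and_batch_prompt

parse_and_batch_prompt("Hi")           # case 1 -> [{'content': 'Hi', 'is_tokens': False}]
parse_and_batch_prompt(["Hi", "Yo"])   # case 2 -> one ParsedText per string
parse_and_batch_prompt([1, 2, 3])      # case 3 -> [{'content': [1, 2, 3], 'is_tokens': True}]
parse_and_batch_prompt([[1, 2], [3]])  # case 4 -> one ParsedTokens per token list
# parse_and_batch_prompt([])           # would raise ValueError: the batch must be non-empty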

class TextPrompt(TypedDict):
    """Schema for a text prompt."""

    prompt: str
    """The input text to be tokenized before passing to the model."""

    multi_modal_data: NotRequired["MultiModalData"]
    """
    Optional multi-modal data to pass to the model,
    if the model supports it.
    """


class TokensPrompt(TypedDict):
    """Schema for a tokenized prompt."""

    prompt_token_ids: List[int]
    """A list of token IDs to pass to the model."""

    multi_modal_data: NotRequired["MultiModalData"]
    """
    Optional multi-modal data to pass to the model,
    if the model supports it.
    """


class TextTokensPrompt(TypedDict):
    """It is assumed that :attr:`prompt` is consistent with
    :attr:`prompt_token_ids`. This is currently used in
    :class:`AsyncLLMEngine` for logging both the text and token IDs."""

    prompt: str
    """The prompt text."""

    prompt_token_ids: List[int]
    """The token IDs of the prompt."""

    multi_modal_data: NotRequired["MultiModalData"]
    """
    Optional multi-modal data to pass to the model,
    if the model supports it.
    """


PromptStrictInputs = Union[str, TextPrompt, TokensPrompt]
"""
The inputs to the LLM, which can take one of the following forms:

- A text prompt (:class:`str` or :class:`TextPrompt`)
- A tokenized prompt (:class:`TokensPrompt`)
"""

PromptInputs = Union[str, TextPrompt, TokensPrompt, TextTokensPrompt]
"""Same as :const:`PromptStrictInputs` but additionally accepts
:class:`TextTokensPrompt`."""


class LLMInputs(TypedDict):
    """
    The inputs in :class:`~vllm.LLMEngine` before they are
    passed to the model executor.
    """

    prompt_token_ids: List[int]
    """The token IDs of the prompt."""

    prompt: NotRequired[Optional[str]]
    """
    The original prompt text corresponding to the token IDs, if available.
    """

    multi_modal_data: NotRequired[Optional["MultiModalData"]]
    """
    Optional multi-modal data to pass to the model,
    if the model supports it.
    """
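Tying the schemas together, a sketch (not part of the diff; the token IDs are placeholders rather than real tokenizer output):

from vllm.inputs import LLMInputs, TextPrompt, TokensPrompt

# User-facing forms accepted as PromptStrictInputs:
as_plain_str = "What is the capital of France?"
as_text = TextPrompt(prompt="What is the capital of France?")
as_tokens = TokensPrompt(prompt_token_ids=[101, 2054, 2003, 102])

# Engine-internal form after tokenization; `prompt` is optional here.
inputs = LLMInputs(prompt_token_ids=[101, 2054, 2003, 102],
                   prompt="What is the capital of France?")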
vllm/inputs/registry.py (new file, 207 lines)
@@ -0,0 +1,207 @@
import functools
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Callable, Dict, Optional, Tuple, Type,
                    TypeVar)

from torch import nn
from transformers import PretrainedConfig

from vllm.logger import init_logger

from .data import LLMInputs

if TYPE_CHECKING:
    from vllm.config import ModelConfig, VisionLanguageConfig
    from vllm.multimodal import MultiModalData
    from vllm.sequence import SequenceData

logger = init_logger(__name__)

C = TypeVar("C", bound=PretrainedConfig)


@dataclass(frozen=True)
class InputContext:
    """
    Contains information about the model which may be used to
    modify the inputs.
    """

    model_config: "ModelConfig"
    """The configuration of the model."""

    def get_multimodal_config(self) -> "VisionLanguageConfig":
        """
        Get the multimodal configuration of the model.

        Raises:
            ValueError: If the model is not multimodal.
        """

        multimodal_config = self.model_config.multimodal_config
        if multimodal_config is None:
            raise ValueError("No multimodal config found")

        return multimodal_config

    def get_hf_config(self, hf_config_type: Type[C]) -> C:
        """
        Get the HuggingFace configuration
        (:class:`transformers.PretrainedConfig`) of the model,
        additionally checking its type.

        Raises:
            TypeError: If the model is not of the specified type.
        """

        hf_config = self.model_config.hf_config
        if not isinstance(hf_config, hf_config_type):
            raise TypeError("Invalid type of HuggingFace config. "
                            f"Expected type: {hf_config_type}, but "
                            f"found type: {type(hf_config)}")

        return hf_config
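For illustration, a per-model hook might use the context like this (a sketch assuming a transformers version that ships LlavaConfig; the function name is made up):

from transformers import LlavaConfig

from vllm.inputs import InputContext


def read_image_token(ctx: InputContext) -> int:
    # Raises TypeError unless the model's hf_config really is a LlavaConfig.
    hf_config = ctx.get_hf_config(LlavaConfig)
    return hf_config.image_token_index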
N = TypeVar("N", bound=Type[nn.Module])

DummyDataFactory = Callable[[InputContext, int],
                            Tuple["SequenceData", Optional["MultiModalData"]]]
"""
Create dummy data to be inputted into the model.

Note:
    :data:`InputProcessor` is not applied to the dummy data.
"""

InputProcessor = Callable[[InputContext, LLMInputs], LLMInputs]
"""Preprocess the inputs to the model."""


class InputRegistry:
    """
    A registry to dispatch data processing
    according to the target model.
    """

    def __init__(self) -> None:
        self._dummy_factories_by_model_type: Dict[Type[nn.Module],
                                                  DummyDataFactory] = {}
        self._input_processors_by_model_type: Dict[Type[nn.Module],
                                                   InputProcessor] = {}

    def _default_dummy_data_factory(
        self,
        ctx: InputContext,
        seq_len: int,
    ) -> Tuple["SequenceData", Optional["MultiModalData"]]:
        """
        The default dummy data factory represents the longest possible text
        that can be inputted to the model.

        Note:
            :data:`InputProcessor` is not applied to the dummy data.
        """
        # Avoid circular import
        from vllm.sequence import SequenceData

        dummy_seq_data = SequenceData([0] * seq_len)
        dummy_multi_modal_data = None

        return dummy_seq_data, dummy_multi_modal_data

    def register_dummy_data(self, factory: DummyDataFactory):
        """
        Register a dummy data factory to a model class.

        During memory profiling, the provided function is invoked to create
        dummy data to be inputted into the model. The resulting memory usage
        should be an upper bound of what the model would use at inference time.
        """

        def wrapper(model_cls: N) -> N:
            if model_cls in self._dummy_factories_by_model_type:
                logger.warning(
                    "Model class %s already has dummy data "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            self._dummy_factories_by_model_type[model_cls] = factory

            return model_cls

        return wrapper

    def dummy_data_for_profiling(self, model_config: "ModelConfig",
                                 seq_len: int):
        """
        Create dummy data for profiling the memory usage of a model.

        The model is identified by ``model_config``.

        TODO: Add guide [ref: PR #5276]
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)
        dummy_factory = self._dummy_factories_by_model_type \
            .get(model_cls, self._default_dummy_data_factory)

        return dummy_factory(InputContext(model_config), seq_len)

    def _default_input_processor(self, ctx: InputContext,
                                 inputs: LLMInputs) -> LLMInputs:
        """The default input processor is a no-op."""
        return inputs

    def register_input_processor(self, processor: InputProcessor):
        """
        Register an input processor to a model class.

        The provided function is invoked on each input to the model. This
        happens before :meth:`~vllm.multimodal.MultiModalRegistry.map_input`.

        See also:
            :ref:`input_processing_pipeline`
        """

        def wrapper(model_cls: N) -> N:
            if model_cls in self._input_processors_by_model_type:
                logger.warning(
                    "Model class %s already has input processor "
                    "registered to %s. It is overwritten by the new one.",
                    model_cls, self)

            self._input_processors_by_model_type[model_cls] = processor

            return model_cls

        return wrapper

    def process_input(self, model_config: "ModelConfig",
                      inputs: LLMInputs) -> LLMInputs:
        """
        Apply an input processor to an instance of model inputs.

        The model is identified by ``model_config``.

        See also:
            :ref:`input_processing_pipeline`
        """
        # Avoid circular import
        from vllm.model_executor.model_loader import get_model_architecture

        model_cls, _ = get_model_architecture(model_config)

        processor = self._input_processors_by_model_type \
            .get(model_cls, self._default_input_processor)

        return processor(InputContext(model_config), inputs)

    def create_input_processor(self, model_config: "ModelConfig"):
        """
        Create an input processor (see :meth:`process_input`) for a
        specific model.
        """
        return functools.partial(self.process_input, model_config)
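Putting the registry together, a hedged end-to-end sketch (MyModel, the factory, and the processor are hypothetical; only the registry API comes from this PR):

from typing import Optional, Tuple

from torch import nn

from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
from vllm.multimodal import MultiModalData
from vllm.sequence import SequenceData


def dummy_data_for_my_model(
        ctx: InputContext,
        seq_len: int) -> Tuple[SequenceData, Optional[MultiModalData]]:
    # Memory upper bound: placeholder tokens at the full sequence length.
    return SequenceData([0] * seq_len), None


def process_my_model_inputs(ctx: InputContext,
                            inputs: LLMInputs) -> LLMInputs:
    # No-op here; a real processor might expand multimodal placeholder tokens.
    return inputs


@INPUT_REGISTRY.register_dummy_data(dummy_data_for_my_model)
@INPUT_REGISTRY.register_input_processor(process_my_model_inputs)
class MyModel(nn.Module):
    ...


# Engine side: both lookups dispatch on the architecture of `model_config`.
# seq_data, mm_data = INPUT_REGISTRY.dummy_data_for_profiling(model_config, 2048)
# processed = INPUT_REGISTRY.process_input(model_config, llm_inputs)
# bound = INPUT_REGISTRY.create_input_processor(model_config)  # functools.partial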