[Core] [Bugfix] Add Input Embeddings (#15428)

Signed-off-by: Andrew Sansom <andrew@protopia.ai>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: 临景 <linjing.yx@alibaba-inc.com>
Co-authored-by: Bryce1010 <bryceyx@gmail.com>
Co-authored-by: Nan2018 <nan@protopia.ai>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Andrew Sansom
2025-05-02 03:06:39 -05:00
committed by GitHub
parent 9e2de9b9e9
commit cc2a77d7f1
22 changed files with 691 additions and 113 deletions

View File

@@ -1,9 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
from .data import (DecoderOnlyInputs, EncoderDecoderInputs,
from .data import (DecoderOnlyInputs, EmbedsInputs, EncoderDecoderInputs,
ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType,
SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs,
TokensPrompt, build_explicit_enc_dec_prompt,
TokensPrompt, build_explicit_enc_dec_prompt, embeds_inputs,
to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts)
from .registry import (DummyData, InputContext, InputProcessingContext,
InputRegistry)
@@ -21,7 +21,9 @@ __all__ = [
"SingletonPrompt",
"ExplicitEncoderDecoderPrompt",
"TokenInputs",
"EmbedsInputs",
"token_inputs",
"embeds_inputs",
"DecoderOnlyInputs",
"EncoderDecoderInputs",
"ProcessorInputs",

View File

@@ -2,6 +2,7 @@
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, Union, cast
import torch
from typing_extensions import NotRequired, TypedDict, TypeVar
if TYPE_CHECKING:
@@ -63,12 +64,20 @@ class TokensPrompt(TypedDict):
"""
SingletonPrompt = Union[str, TextPrompt, TokensPrompt]
class EmbedsPrompt(TypedDict):
    """Schema for a prompt supplied directly as token embeddings."""

    prompt_embeds: torch.Tensor
    """The embeddings of the prompt, of shape ``(seq_len, hidden_size)``."""
SingletonPrompt = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt]
"""
Set of possible schemas for a single prompt:
- A text prompt (:class:`str` or :class:`TextPrompt`)
- A tokenized prompt (:class:`TokensPrompt`)
- An embeddings prompt (:class:`EmbedsPrompt`)
Note that "singleton" is as opposed to a data structure
which encapsulates multiple prompts, i.e. of the sort
@@ -129,6 +138,7 @@ both decoder-only and encoder/decoder input types:
- A text prompt (:class:`str` or :class:`TextPrompt`)
- A tokenized prompt (:class:`TokensPrompt`)
- An embeddings prompt (:class:`EmbedsPrompt`)
- A single data structure containing both an encoder and a decoder prompt
(:class:`ExplicitEncoderDecoderPrompt`)
"""
@@ -176,7 +186,27 @@ def token_inputs(
return inputs
DecoderOnlyInputs = Union[TokenInputs, "MultiModalInputs"]
class EmbedsInputs(TypedDict):
    """Represents embeddings-based inputs."""

    type: Literal["embeds"]
    """The type of inputs."""

    prompt_embeds: torch.Tensor
    """The embeddings of the prompt."""


def embeds_inputs(prompt_embeds: torch.Tensor) -> EmbedsInputs:
    """Construct :class:`EmbedsInputs` from prompt embeddings.

    Args:
        prompt_embeds: The prompt embeddings, of shape
            ``(seq_len, hidden_size)``.
    """
    # Unlike token_inputs, there are no optional fields here: the tag and
    # the embeddings tensor are both required.
    return EmbedsInputs(type="embeds", prompt_embeds=prompt_embeds)
DecoderOnlyInputs = Union[TokenInputs, EmbedsInputs, "MultiModalInputs"]
"""
The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor.
@@ -198,7 +228,7 @@ class EncoderDecoderInputs(TypedDict):
"""The inputs for the decoder portion."""
SingletonInputs = Union[TokenInputs, "MultiModalInputs"]
SingletonInputs = Union[TokenInputs, EmbedsInputs, "MultiModalInputs"]
"""
A processed :class:`SingletonPrompt` which can be passed to
:class:`vllm.sequence.Sequence`.

View File

@@ -6,8 +6,9 @@ from typing_extensions import TypeIs
from vllm.utils import is_list_of
from .data import (ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType,
SingletonInputs, SingletonPrompt, TextPrompt, TokensPrompt)
from .data import (EmbedsInputs, EmbedsPrompt, ExplicitEncoderDecoderPrompt,
ProcessorInputs, PromptType, SingletonInputs,
SingletonPrompt, TextPrompt, TokensPrompt)
class ParsedText(TypedDict):
@@ -84,30 +85,69 @@ class ParsedTokensPrompt(TypedDict):
content: TokensPrompt
class ParsedEmbedsPrompt(TypedDict):
    """Tagged wrapper marking a parsed prompt as an embeddings prompt."""

    type: Literal['embeds']
    content: EmbedsPrompt
@overload
def parse_singleton_prompt(prompt: str) -> ParsedStrPrompt:
    ...


@overload
def parse_singleton_prompt(prompt: TextPrompt) -> ParsedTextPrompt:
    ...


@overload
def parse_singleton_prompt(prompt: TokensPrompt) -> ParsedTokensPrompt:
    ...


@overload
def parse_singleton_prompt(prompt: EmbedsPrompt) -> ParsedEmbedsPrompt:
    ...


def parse_singleton_prompt(
    prompt: SingletonPrompt,
) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt,
           ParsedEmbedsPrompt]:
    """Classify a singleton prompt by inspecting its type and keys.

    Dict prompts are disambiguated by key: ``prompt_embeds`` wins over
    ``prompt_token_ids``, which wins over ``prompt``.

    Raises:
        TypeError: If the prompt matches none of the known schemas.
    """
    if isinstance(prompt, str):
        return ParsedStrPrompt(type="str", content=prompt)
    elif isinstance(prompt, dict):
        # Type ignores are because mypy does not correctly infer the
        # TypedDicts; Pyright does succeed.
        if "prompt_embeds" in prompt:
            return ParsedEmbedsPrompt(
                type="embeds", content=prompt)  # type: ignore[typeddict-item]
        elif "prompt_token_ids" in prompt:
            return ParsedTokensPrompt(
                type="tokens", content=prompt)  # type: ignore[typeddict-item]
        elif "prompt" in prompt:
            return ParsedTextPrompt(type="text", content=prompt)
    raise TypeError(
        "inputs must be a string, TextPrompt, TokensPrompt, or EmbedsPrompt")
def is_token_prompt(prompt: PromptType) -> TypeIs[TokensPrompt]:
    """Whether the prompt is a :class:`TokensPrompt` (carries token ids)."""
    if not isinstance(prompt, dict):
        return False
    return "prompt_token_ids" in prompt
def is_embeds_prompt(prompt: PromptType) -> TypeIs[EmbedsPrompt]:
    """Whether the prompt is an :class:`EmbedsPrompt` (carries embeddings)."""
    if not isinstance(prompt, dict):
        return False
    return "prompt_embeds" in prompt
def is_explicit_encoder_decoder_prompt(
        prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]:
    """Whether the prompt explicitly pairs encoder and decoder prompts."""
    if not isinstance(prompt, dict):
        return False
    return "encoder_prompt" in prompt
def is_embeds_inputs(inputs: SingletonInputs) -> TypeIs[EmbedsInputs]:
    """Whether the processed inputs are embeddings-based (``type == "embeds"``)."""
    if not isinstance(inputs, dict):
        return False
    return inputs["type"] == "embeds"
def split_enc_dec_inputs(
inputs: ProcessorInputs,
) -> tuple[Optional[SingletonInputs], SingletonInputs]:

View File

@@ -6,6 +6,7 @@ from typing import Any, Optional, Union, cast
from typing_extensions import assert_never
from vllm import envs
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@@ -15,9 +16,11 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from .data import (DecoderOnlyInputs, EncoderDecoderInputs, ProcessorInputs,
PromptType, SingletonInputs, SingletonPrompt, token_inputs)
from .parse import (ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt,
from .data import (DecoderOnlyInputs, EmbedsInputs, EncoderDecoderInputs,
ProcessorInputs, PromptType, SingletonInputs,
SingletonPrompt, TokenInputs, embeds_inputs, token_inputs)
from .parse import (ParsedEmbedsPrompt, ParsedStrPrompt, ParsedTextPrompt,
ParsedTokensPrompt, is_embeds_inputs,
is_explicit_encoder_decoder_prompt, parse_singleton_prompt)
logger = init_logger(__name__)
@@ -328,6 +331,10 @@ class InputPreprocessor:
* :class:`SingletonInputs` instance
"""
parsed = parse_singleton_prompt(prompt)
if parsed["type"] == "embeds":
return self._process_prompt_embeds(parsed)
prompt_text, prompt_token_ids, token_type_ids, cache_salt = \
self._get_prompt_data(parsed)
@@ -359,6 +366,8 @@ class InputPreprocessor:
cache_salt=cache_salt,
)
assert_never(parsed)
async def _prompt_to_llm_inputs_async(
self,
prompt: SingletonPrompt,
@@ -369,6 +378,9 @@ class InputPreprocessor:
"""Async version of :meth:`_extract_prompt_components`."""
parsed = parse_singleton_prompt(prompt)
if parsed["type"] == "embeds":
return self._process_prompt_embeds(parsed)
prompt_text, prompt_token_ids, token_type_ids, cache_salt = \
self._get_prompt_data(parsed)
@@ -399,10 +411,34 @@ class InputPreprocessor:
cache_salt=cache_salt,
)
def _process_prompt_embeds(self,
                           parsed: ParsedEmbedsPrompt) -> EmbedsInputs:
    """Validate an embeddings prompt and convert it to :class:`EmbedsInputs`.

    Raises:
        ValueError: If running under V1 (prompt_embeds is V0-only), or if
            the tensor is not 2-D after optional batch-dim squeezing.
    """
    if envs.VLLM_USE_V1:
        raise ValueError("prompt_embeds is only available in V0.")

    prompt_embeds = parsed["content"]["prompt_embeds"]

    # prompt_embeds must be (seq_len, hidden_size), but if the user
    # passes in a batch of size 1, i.e. (1, seq_len, hidden_size),
    # we can unambiguously process the intent by squeezing the batch
    # dimension.
    if prompt_embeds.ndim == 3 and prompt_embeds.shape[0] == 1:
        prompt_embeds = prompt_embeds.squeeze(dim=0)

    if prompt_embeds.ndim != 2:
        raise ValueError(
            "prompt_embeds must be of shape (seq_len, hidden_size).")

    return embeds_inputs(prompt_embeds=prompt_embeds)
assert_never(parsed)
def _build_enc_dec_llm_inputs(
self,
encoder_inputs: SingletonInputs,
decoder_inputs: Optional[SingletonInputs],
encoder_inputs: Union[TokenInputs, MultiModalInputs],
decoder_inputs: Optional[Union[TokenInputs, MultiModalInputs]],
) -> EncoderDecoderInputs:
if (encoder_inputs["type"] == "token"
or encoder_inputs["type"] == "multimodal"):
@@ -410,6 +446,9 @@ class InputPreprocessor:
else:
assert_never(encoder_inputs) # type: ignore[arg-type]
# Mypy does not correctly infer that EmbedsInputs is impossible
assert "prompt_token_ids" in encoder_inputs
if decoder_inputs is None:
if self.model_config.hf_config.model_type == "whisper":
# For Whisper models, the text prompt should go to the decoder.
@@ -441,7 +480,8 @@ class InputPreprocessor:
def _separate_enc_dec_inputs_from_mm_processor_outputs(
self,
inputs: SingletonInputs,
decoder_inputs_to_override: Optional[SingletonInputs] = None,
decoder_inputs_to_override: Optional[Union[TokenInputs,
MultiModalInputs]] = None,
) -> tuple[SingletonInputs, SingletonInputs]:
"""
For encoder/decoder models only:
@@ -540,6 +580,8 @@ class InputPreprocessor:
# For multimodal model, override decoder prompt from processor
# with explicit decoder prompt.
if self.model_config.is_multimodal_model:
assert decoder_inputs is None or not is_embeds_inputs(
decoder_inputs)
encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs(
encoder_inputs, decoder_inputs))
@@ -555,9 +597,12 @@ class InputPreprocessor:
inputs))
else:
encoder_inputs = inputs
decoder_inputs = None
# Mypy does not do type inference well with TypedDicts with Literal
# values.
assert not is_embeds_inputs(encoder_inputs)
assert decoder_inputs is None or not is_embeds_inputs(decoder_inputs)
return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
async def _process_encoder_decoder_prompt_async(
@@ -590,6 +635,8 @@ class InputPreprocessor:
# For multimodal model, override decoder prompt from processor
# with explicit decoder prompt.
if self.model_config.is_multimodal_model:
assert decoder_inputs is None or not is_embeds_inputs(
decoder_inputs)
encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs(
encoder_inputs, decoder_inputs))
@@ -605,9 +652,12 @@ class InputPreprocessor:
inputs))
else:
encoder_inputs = inputs
decoder_inputs = None
# Mypy does not do type inference well with TypedDicts with Literal
# values.
assert not is_embeds_inputs(encoder_inputs)
assert decoder_inputs is None or not is_embeds_inputs(decoder_inputs)
return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
def _build_decoder_only_llm_inputs(
@@ -617,10 +667,15 @@ class InputPreprocessor:
) -> DecoderOnlyInputs:
if (prompt_inputs["type"] == "token"
or prompt_inputs["type"] == "multimodal"):
# Mypy does not do type inference well with TypedDicts and Literal
# values
assert not is_embeds_inputs(prompt_inputs)
prompt_inputs["prompt_token_ids"] = self._apply_prompt_adapter(
prompt_inputs["prompt_token_ids"],
prompt_adapter_request=prompt_adapter_request,
)
elif (prompt_inputs["type"] == "embeds"):
pass
else:
assert_never(prompt_inputs) # type: ignore[arg-type]