[Core] [Bugfix] Add Input Embeddings (#15428)
Signed-off-by: Andrew Sansom <andrew@protopia.ai> Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Co-authored-by: 临景 <linjing.yx@alibaba-inc.com> Co-authored-by: Bryce1010 <bryceyx@gmail.com> Co-authored-by: Nan2018 <nan@protopia.ai> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -1,9 +1,9 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from .data import (DecoderOnlyInputs, EncoderDecoderInputs,
|
||||
from .data import (DecoderOnlyInputs, EmbedsInputs, EncoderDecoderInputs,
|
||||
ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType,
|
||||
SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs,
|
||||
TokensPrompt, build_explicit_enc_dec_prompt,
|
||||
TokensPrompt, build_explicit_enc_dec_prompt, embeds_inputs,
|
||||
to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts)
|
||||
from .registry import (DummyData, InputContext, InputProcessingContext,
|
||||
InputRegistry)
|
||||
@@ -21,7 +21,9 @@ __all__ = [
|
||||
"SingletonPrompt",
|
||||
"ExplicitEncoderDecoderPrompt",
|
||||
"TokenInputs",
|
||||
"EmbedsInputs",
|
||||
"token_inputs",
|
||||
"embeds_inputs",
|
||||
"DecoderOnlyInputs",
|
||||
"EncoderDecoderInputs",
|
||||
"ProcessorInputs",
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
from collections.abc import Iterable
|
||||
from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, Union, cast
|
||||
|
||||
import torch
|
||||
from typing_extensions import NotRequired, TypedDict, TypeVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -63,12 +64,20 @@ class TokensPrompt(TypedDict):
|
||||
"""
|
||||
|
||||
|
||||
SingletonPrompt = Union[str, TextPrompt, TokensPrompt]
|
||||
class EmbedsPrompt(TypedDict):
|
||||
"""Schema for a prompt provided via token embeddings."""
|
||||
|
||||
prompt_embeds: torch.Tensor
|
||||
"""The embeddings of the prompt."""
|
||||
|
||||
|
||||
SingletonPrompt = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt]
|
||||
"""
|
||||
Set of possible schemas for a single prompt:
|
||||
|
||||
- A text prompt (:class:`str` or :class:`TextPrompt`)
|
||||
- A tokenized prompt (:class:`TokensPrompt`)
|
||||
- An embeddings prompt (:class:`EmbedsPrompt`)
|
||||
|
||||
Note that "singleton" is as opposed to a data structure
|
||||
which encapsulates multiple prompts, i.e. of the sort
|
||||
@@ -129,6 +138,7 @@ both decoder-only and encoder/decoder input types:
|
||||
|
||||
- A text prompt (:class:`str` or :class:`TextPrompt`)
|
||||
- A tokenized prompt (:class:`TokensPrompt`)
|
||||
- An embeddings prompt (:class:`EmbedsPrompt`)
|
||||
- A single data structure containing both an encoder and a decoder prompt
|
||||
(:class:`ExplicitEncoderDecoderPrompt`)
|
||||
"""
|
||||
@@ -176,7 +186,27 @@ def token_inputs(
|
||||
return inputs
|
||||
|
||||
|
||||
DecoderOnlyInputs = Union[TokenInputs, "MultiModalInputs"]
|
||||
class EmbedsInputs(TypedDict):
|
||||
"""Represents embeddings-based inputs."""
|
||||
|
||||
type: Literal["embeds"]
|
||||
"""The type of inputs."""
|
||||
|
||||
prompt_embeds: torch.Tensor
|
||||
"""The embeddings of the prompt."""
|
||||
|
||||
|
||||
def embeds_inputs(prompt_embeds: torch.Tensor) -> EmbedsInputs:
|
||||
"""Construct :class:`EmbedsInputs` from optional values."""
|
||||
inputs = EmbedsInputs(
|
||||
type="embeds",
|
||||
prompt_embeds=prompt_embeds,
|
||||
)
|
||||
|
||||
return inputs
|
||||
|
||||
|
||||
DecoderOnlyInputs = Union[TokenInputs, EmbedsInputs, "MultiModalInputs"]
|
||||
"""
|
||||
The inputs in :class:`~vllm.LLMEngine` before they are
|
||||
passed to the model executor.
|
||||
@@ -198,7 +228,7 @@ class EncoderDecoderInputs(TypedDict):
|
||||
"""The inputs for the decoder portion."""
|
||||
|
||||
|
||||
SingletonInputs = Union[TokenInputs, "MultiModalInputs"]
|
||||
SingletonInputs = Union[TokenInputs, EmbedsInputs, "MultiModalInputs"]
|
||||
"""
|
||||
A processed :class:`SingletonPrompt` which can be passed to
|
||||
:class:`vllm.sequence.Sequence`.
|
||||
|
||||
@@ -6,8 +6,9 @@ from typing_extensions import TypeIs
|
||||
|
||||
from vllm.utils import is_list_of
|
||||
|
||||
from .data import (ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType,
|
||||
SingletonInputs, SingletonPrompt, TextPrompt, TokensPrompt)
|
||||
from .data import (EmbedsInputs, EmbedsPrompt, ExplicitEncoderDecoderPrompt,
|
||||
ProcessorInputs, PromptType, SingletonInputs,
|
||||
SingletonPrompt, TextPrompt, TokensPrompt)
|
||||
|
||||
|
||||
class ParsedText(TypedDict):
|
||||
@@ -84,30 +85,69 @@ class ParsedTokensPrompt(TypedDict):
|
||||
content: TokensPrompt
|
||||
|
||||
|
||||
class ParsedEmbedsPrompt(TypedDict):
|
||||
type: Literal['embeds']
|
||||
content: EmbedsPrompt
|
||||
|
||||
|
||||
@overload
|
||||
def parse_singleton_prompt(prompt: str) -> ParsedStrPrompt:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def parse_singleton_prompt(prompt: TextPrompt) -> ParsedTextPrompt:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def parse_singleton_prompt(prompt: TokensPrompt) -> ParsedTokensPrompt:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def parse_singleton_prompt(prompt: EmbedsPrompt) -> ParsedEmbedsPrompt:
|
||||
...
|
||||
|
||||
|
||||
def parse_singleton_prompt(
|
||||
prompt: SingletonPrompt,
|
||||
) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt]:
|
||||
) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt,
|
||||
ParsedEmbedsPrompt]:
|
||||
if isinstance(prompt, str):
|
||||
return ParsedStrPrompt(type="str", content=prompt)
|
||||
elif isinstance(prompt, dict):
|
||||
if "prompt_token_ids" in prompt:
|
||||
return ParsedTokensPrompt(type="tokens",
|
||||
content=prompt) # type: ignore
|
||||
# Type ignores are because mypy does not correctly infer the TypedDicts
|
||||
# Pyright does succeed.
|
||||
if "prompt_embeds" in prompt:
|
||||
return ParsedEmbedsPrompt(
|
||||
type="embeds", content=prompt) # type: ignore[typeddict-item]
|
||||
elif "prompt_token_ids" in prompt:
|
||||
return ParsedTokensPrompt(
|
||||
type="tokens", content=prompt) # type: ignore[typeddict-item]
|
||||
elif "prompt" in prompt:
|
||||
return ParsedTextPrompt(type="text", content=prompt)
|
||||
|
||||
raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt")
|
||||
raise TypeError(
|
||||
"inputs must be a string, TextPrompt, TokensPrompt, or EmbedsPrompt")
|
||||
|
||||
|
||||
def is_token_prompt(prompt: PromptType) -> TypeIs[TokensPrompt]:
|
||||
return isinstance(prompt, dict) and "prompt_token_ids" in prompt
|
||||
|
||||
|
||||
def is_embeds_prompt(prompt: PromptType) -> TypeIs[EmbedsPrompt]:
|
||||
return isinstance(prompt, dict) and "prompt_embeds" in prompt
|
||||
|
||||
|
||||
def is_explicit_encoder_decoder_prompt(
|
||||
prompt: PromptType) -> TypeIs[ExplicitEncoderDecoderPrompt]:
|
||||
return isinstance(prompt, dict) and "encoder_prompt" in prompt
|
||||
|
||||
|
||||
def is_embeds_inputs(inputs: SingletonInputs) -> TypeIs[EmbedsInputs]:
|
||||
return isinstance(inputs, dict) and inputs["type"] == "embeds"
|
||||
|
||||
|
||||
def split_enc_dec_inputs(
|
||||
inputs: ProcessorInputs,
|
||||
) -> tuple[Optional[SingletonInputs], SingletonInputs]:
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Any, Optional, Union, cast
|
||||
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm import envs
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
@@ -15,9 +16,11 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
|
||||
|
||||
from .data import (DecoderOnlyInputs, EncoderDecoderInputs, ProcessorInputs,
|
||||
PromptType, SingletonInputs, SingletonPrompt, token_inputs)
|
||||
from .parse import (ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt,
|
||||
from .data import (DecoderOnlyInputs, EmbedsInputs, EncoderDecoderInputs,
|
||||
ProcessorInputs, PromptType, SingletonInputs,
|
||||
SingletonPrompt, TokenInputs, embeds_inputs, token_inputs)
|
||||
from .parse import (ParsedEmbedsPrompt, ParsedStrPrompt, ParsedTextPrompt,
|
||||
ParsedTokensPrompt, is_embeds_inputs,
|
||||
is_explicit_encoder_decoder_prompt, parse_singleton_prompt)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -328,6 +331,10 @@ class InputPreprocessor:
|
||||
* :class:`SingletonInputs` instance
|
||||
"""
|
||||
parsed = parse_singleton_prompt(prompt)
|
||||
|
||||
if parsed["type"] == "embeds":
|
||||
return self._process_prompt_embeds(parsed)
|
||||
|
||||
prompt_text, prompt_token_ids, token_type_ids, cache_salt = \
|
||||
self._get_prompt_data(parsed)
|
||||
|
||||
@@ -359,6 +366,8 @@ class InputPreprocessor:
|
||||
cache_salt=cache_salt,
|
||||
)
|
||||
|
||||
assert_never(parsed)
|
||||
|
||||
async def _prompt_to_llm_inputs_async(
|
||||
self,
|
||||
prompt: SingletonPrompt,
|
||||
@@ -369,6 +378,9 @@ class InputPreprocessor:
|
||||
"""Async version of :meth:`_extract_prompt_components`."""
|
||||
parsed = parse_singleton_prompt(prompt)
|
||||
|
||||
if parsed["type"] == "embeds":
|
||||
return self._process_prompt_embeds(parsed)
|
||||
|
||||
prompt_text, prompt_token_ids, token_type_ids, cache_salt = \
|
||||
self._get_prompt_data(parsed)
|
||||
|
||||
@@ -399,10 +411,34 @@ class InputPreprocessor:
|
||||
cache_salt=cache_salt,
|
||||
)
|
||||
|
||||
def _process_prompt_embeds(self,
|
||||
parsed: ParsedEmbedsPrompt) -> EmbedsInputs:
|
||||
if envs.VLLM_USE_V1:
|
||||
raise ValueError("prompt_embeds is only available in V0.")
|
||||
|
||||
prompt_embeds_content = parsed["content"]
|
||||
|
||||
prompt_embeds = prompt_embeds_content["prompt_embeds"]
|
||||
|
||||
# prompt_embeds must be (seq_len, hidden_size), but if the user
|
||||
# passes in a batch of size 1, i.e. (1, seq_len, hidden_size),
|
||||
# we can unambiguously process the intent by squeezing the batch
|
||||
# dimension.
|
||||
if prompt_embeds.ndim == 3 and prompt_embeds.shape[0] == 1:
|
||||
prompt_embeds = prompt_embeds.squeeze(dim=0)
|
||||
|
||||
if prompt_embeds.ndim != 2:
|
||||
raise ValueError(
|
||||
"prompt_embeds must be of shape (seq_len, hidden_size).")
|
||||
|
||||
return embeds_inputs(prompt_embeds=prompt_embeds)
|
||||
|
||||
assert_never(parsed)
|
||||
|
||||
def _build_enc_dec_llm_inputs(
|
||||
self,
|
||||
encoder_inputs: SingletonInputs,
|
||||
decoder_inputs: Optional[SingletonInputs],
|
||||
encoder_inputs: Union[TokenInputs, MultiModalInputs],
|
||||
decoder_inputs: Optional[Union[TokenInputs, MultiModalInputs]],
|
||||
) -> EncoderDecoderInputs:
|
||||
if (encoder_inputs["type"] == "token"
|
||||
or encoder_inputs["type"] == "multimodal"):
|
||||
@@ -410,6 +446,9 @@ class InputPreprocessor:
|
||||
else:
|
||||
assert_never(encoder_inputs) # type: ignore[arg-type]
|
||||
|
||||
# Mypy does not correctly infer that EmbedsInputs is impossible
|
||||
assert "prompt_token_ids" in encoder_inputs
|
||||
|
||||
if decoder_inputs is None:
|
||||
if self.model_config.hf_config.model_type == "whisper":
|
||||
# For Whisper models, the text prompt should go to the decoder.
|
||||
@@ -441,7 +480,8 @@ class InputPreprocessor:
|
||||
def _separate_enc_dec_inputs_from_mm_processor_outputs(
|
||||
self,
|
||||
inputs: SingletonInputs,
|
||||
decoder_inputs_to_override: Optional[SingletonInputs] = None,
|
||||
decoder_inputs_to_override: Optional[Union[TokenInputs,
|
||||
MultiModalInputs]] = None,
|
||||
) -> tuple[SingletonInputs, SingletonInputs]:
|
||||
"""
|
||||
For encoder/decoder models only:
|
||||
@@ -540,6 +580,8 @@ class InputPreprocessor:
|
||||
# For multimodal model, override decoder prompt from processor
|
||||
# with explicit decoder prompt.
|
||||
if self.model_config.is_multimodal_model:
|
||||
assert decoder_inputs is None or not is_embeds_inputs(
|
||||
decoder_inputs)
|
||||
encoder_inputs, decoder_inputs = (
|
||||
self._separate_enc_dec_inputs_from_mm_processor_outputs(
|
||||
encoder_inputs, decoder_inputs))
|
||||
@@ -555,9 +597,12 @@ class InputPreprocessor:
|
||||
inputs))
|
||||
else:
|
||||
encoder_inputs = inputs
|
||||
|
||||
decoder_inputs = None
|
||||
|
||||
# Mypy does not do type inference well with TypedDicts with Literal
|
||||
# values.
|
||||
assert not is_embeds_inputs(encoder_inputs)
|
||||
assert decoder_inputs is None or not is_embeds_inputs(decoder_inputs)
|
||||
return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
|
||||
|
||||
async def _process_encoder_decoder_prompt_async(
|
||||
@@ -590,6 +635,8 @@ class InputPreprocessor:
|
||||
# For multimodal model, override decoder prompt from processor
|
||||
# with explicit decoder prompt.
|
||||
if self.model_config.is_multimodal_model:
|
||||
assert decoder_inputs is None or not is_embeds_inputs(
|
||||
decoder_inputs)
|
||||
encoder_inputs, decoder_inputs = (
|
||||
self._separate_enc_dec_inputs_from_mm_processor_outputs(
|
||||
encoder_inputs, decoder_inputs))
|
||||
@@ -605,9 +652,12 @@ class InputPreprocessor:
|
||||
inputs))
|
||||
else:
|
||||
encoder_inputs = inputs
|
||||
|
||||
decoder_inputs = None
|
||||
|
||||
# Mypy does not do type inference well with TypedDicts with Literal
|
||||
# values.
|
||||
assert not is_embeds_inputs(encoder_inputs)
|
||||
assert decoder_inputs is None or not is_embeds_inputs(decoder_inputs)
|
||||
return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
|
||||
|
||||
def _build_decoder_only_llm_inputs(
|
||||
@@ -617,10 +667,15 @@ class InputPreprocessor:
|
||||
) -> DecoderOnlyInputs:
|
||||
if (prompt_inputs["type"] == "token"
|
||||
or prompt_inputs["type"] == "multimodal"):
|
||||
# Mypy does not do type inference well with typedicts and Literal
|
||||
# values
|
||||
assert not is_embeds_inputs(prompt_inputs)
|
||||
prompt_inputs["prompt_token_ids"] = self._apply_prompt_adapter(
|
||||
prompt_inputs["prompt_token_ids"],
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
)
|
||||
elif (prompt_inputs["type"] == "embeds"):
|
||||
pass
|
||||
else:
|
||||
assert_never(prompt_inputs) # type: ignore[arg-type]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user