[Core] [Bugfix] Add Input Embeddings (#15428)

Signed-off-by: Andrew Sansom <andrew@protopia.ai>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: 临景 <linjing.yx@alibaba-inc.com>
Co-authored-by: Bryce1010 <bryceyx@gmail.com>
Co-authored-by: Nan2018 <nan@protopia.ai>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Andrew Sansom
2025-05-02 03:06:39 -05:00
committed by GitHub
parent 9e2de9b9e9
commit cc2a77d7f1
22 changed files with 691 additions and 113 deletions

View File

@@ -6,6 +6,7 @@ from typing import Any, Optional, Union, cast
from typing_extensions import assert_never
from vllm import envs
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@@ -15,9 +16,11 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from .data import (DecoderOnlyInputs, EncoderDecoderInputs, ProcessorInputs,
PromptType, SingletonInputs, SingletonPrompt, token_inputs)
from .parse import (ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt,
from .data import (DecoderOnlyInputs, EmbedsInputs, EncoderDecoderInputs,
ProcessorInputs, PromptType, SingletonInputs,
SingletonPrompt, TokenInputs, embeds_inputs, token_inputs)
from .parse import (ParsedEmbedsPrompt, ParsedStrPrompt, ParsedTextPrompt,
ParsedTokensPrompt, is_embeds_inputs,
is_explicit_encoder_decoder_prompt, parse_singleton_prompt)
logger = init_logger(__name__)
@@ -328,6 +331,10 @@ class InputPreprocessor:
* :class:`SingletonInputs` instance
"""
parsed = parse_singleton_prompt(prompt)
if parsed["type"] == "embeds":
return self._process_prompt_embeds(parsed)
prompt_text, prompt_token_ids, token_type_ids, cache_salt = \
self._get_prompt_data(parsed)
@@ -359,6 +366,8 @@ class InputPreprocessor:
cache_salt=cache_salt,
)
assert_never(parsed)
async def _prompt_to_llm_inputs_async(
self,
prompt: SingletonPrompt,
@@ -369,6 +378,9 @@ class InputPreprocessor:
"""Async version of :meth:`_extract_prompt_components`."""
parsed = parse_singleton_prompt(prompt)
if parsed["type"] == "embeds":
return self._process_prompt_embeds(parsed)
prompt_text, prompt_token_ids, token_type_ids, cache_salt = \
self._get_prompt_data(parsed)
@@ -399,10 +411,34 @@ class InputPreprocessor:
cache_salt=cache_salt,
)
def _process_prompt_embeds(self,
parsed: ParsedEmbedsPrompt) -> EmbedsInputs:
if envs.VLLM_USE_V1:
raise ValueError("prompt_embeds is only available in V0.")
prompt_embeds_content = parsed["content"]
prompt_embeds = prompt_embeds_content["prompt_embeds"]
# prompt_embeds must be (seq_len, hidden_size), but if the user
# passes in a batch of size 1, i.e. (1, seq_len, hidden_size),
# we can unambiguously process the intent by squeezing the batch
# dimension.
if prompt_embeds.ndim == 3 and prompt_embeds.shape[0] == 1:
prompt_embeds = prompt_embeds.squeeze(dim=0)
if prompt_embeds.ndim != 2:
raise ValueError(
"prompt_embeds must be of shape (seq_len, hidden_size).")
return embeds_inputs(prompt_embeds=prompt_embeds)
assert_never(parsed)
def _build_enc_dec_llm_inputs(
self,
encoder_inputs: SingletonInputs,
decoder_inputs: Optional[SingletonInputs],
encoder_inputs: Union[TokenInputs, MultiModalInputs],
decoder_inputs: Optional[Union[TokenInputs, MultiModalInputs]],
) -> EncoderDecoderInputs:
if (encoder_inputs["type"] == "token"
or encoder_inputs["type"] == "multimodal"):
@@ -410,6 +446,9 @@ class InputPreprocessor:
else:
assert_never(encoder_inputs) # type: ignore[arg-type]
# Mypy does not correctly infer that EmbedsInputs is impossible
assert "prompt_token_ids" in encoder_inputs
if decoder_inputs is None:
if self.model_config.hf_config.model_type == "whisper":
# For Whisper models, the text prompt should go to the decoder.
@@ -441,7 +480,8 @@ class InputPreprocessor:
def _separate_enc_dec_inputs_from_mm_processor_outputs(
self,
inputs: SingletonInputs,
decoder_inputs_to_override: Optional[SingletonInputs] = None,
decoder_inputs_to_override: Optional[Union[TokenInputs,
MultiModalInputs]] = None,
) -> tuple[SingletonInputs, SingletonInputs]:
"""
For encoder/decoder models only:
@@ -540,6 +580,8 @@ class InputPreprocessor:
# For multimodal model, override decoder prompt from processor
# with explicit decoder prompt.
if self.model_config.is_multimodal_model:
assert decoder_inputs is None or not is_embeds_inputs(
decoder_inputs)
encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs(
encoder_inputs, decoder_inputs))
@@ -555,9 +597,12 @@ class InputPreprocessor:
inputs))
else:
encoder_inputs = inputs
decoder_inputs = None
# Mypy does not do type inference well with TypedDicts with Literal
# values.
assert not is_embeds_inputs(encoder_inputs)
assert decoder_inputs is None or not is_embeds_inputs(decoder_inputs)
return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
async def _process_encoder_decoder_prompt_async(
@@ -590,6 +635,8 @@ class InputPreprocessor:
# For multimodal model, override decoder prompt from processor
# with explicit decoder prompt.
if self.model_config.is_multimodal_model:
assert decoder_inputs is None or not is_embeds_inputs(
decoder_inputs)
encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs(
encoder_inputs, decoder_inputs))
@@ -605,9 +652,12 @@ class InputPreprocessor:
inputs))
else:
encoder_inputs = inputs
decoder_inputs = None
# Mypy does not do type inference well with TypedDicts with Literal
# values.
assert not is_embeds_inputs(encoder_inputs)
assert decoder_inputs is None or not is_embeds_inputs(decoder_inputs)
return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
def _build_decoder_only_llm_inputs(
@@ -617,10 +667,15 @@ class InputPreprocessor:
) -> DecoderOnlyInputs:
if (prompt_inputs["type"] == "token"
or prompt_inputs["type"] == "multimodal"):
# Mypy does not do type inference well with typedicts and Literal
# values
assert not is_embeds_inputs(prompt_inputs)
prompt_inputs["prompt_token_ids"] = self._apply_prompt_adapter(
prompt_inputs["prompt_token_ids"],
prompt_adapter_request=prompt_adapter_request,
)
elif (prompt_inputs["type"] == "embeds"):
pass
else:
assert_never(prompt_inputs) # type: ignore[arg-type]