[Core] [Bugfix] Add Input Embeddings (#15428)
Signed-off-by: Andrew Sansom <andrew@protopia.ai> Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Co-authored-by: 临景 <linjing.yx@alibaba-inc.com> Co-authored-by: Bryce1010 <bryceyx@gmail.com> Co-authored-by: Nan2018 <nan@protopia.ai> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -6,6 +6,7 @@ from typing import Any, Optional, Union, cast
|
||||
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm import envs
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
@@ -15,9 +16,11 @@ from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
|
||||
|
||||
from .data import (DecoderOnlyInputs, EncoderDecoderInputs, ProcessorInputs,
|
||||
PromptType, SingletonInputs, SingletonPrompt, token_inputs)
|
||||
from .parse import (ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt,
|
||||
from .data import (DecoderOnlyInputs, EmbedsInputs, EncoderDecoderInputs,
|
||||
ProcessorInputs, PromptType, SingletonInputs,
|
||||
SingletonPrompt, TokenInputs, embeds_inputs, token_inputs)
|
||||
from .parse import (ParsedEmbedsPrompt, ParsedStrPrompt, ParsedTextPrompt,
|
||||
ParsedTokensPrompt, is_embeds_inputs,
|
||||
is_explicit_encoder_decoder_prompt, parse_singleton_prompt)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -328,6 +331,10 @@ class InputPreprocessor:
|
||||
* :class:`SingletonInputs` instance
|
||||
"""
|
||||
parsed = parse_singleton_prompt(prompt)
|
||||
|
||||
if parsed["type"] == "embeds":
|
||||
return self._process_prompt_embeds(parsed)
|
||||
|
||||
prompt_text, prompt_token_ids, token_type_ids, cache_salt = \
|
||||
self._get_prompt_data(parsed)
|
||||
|
||||
@@ -359,6 +366,8 @@ class InputPreprocessor:
|
||||
cache_salt=cache_salt,
|
||||
)
|
||||
|
||||
assert_never(parsed)
|
||||
|
||||
async def _prompt_to_llm_inputs_async(
|
||||
self,
|
||||
prompt: SingletonPrompt,
|
||||
@@ -369,6 +378,9 @@ class InputPreprocessor:
|
||||
"""Async version of :meth:`_extract_prompt_components`."""
|
||||
parsed = parse_singleton_prompt(prompt)
|
||||
|
||||
if parsed["type"] == "embeds":
|
||||
return self._process_prompt_embeds(parsed)
|
||||
|
||||
prompt_text, prompt_token_ids, token_type_ids, cache_salt = \
|
||||
self._get_prompt_data(parsed)
|
||||
|
||||
@@ -399,10 +411,34 @@ class InputPreprocessor:
|
||||
cache_salt=cache_salt,
|
||||
)
|
||||
|
||||
def _process_prompt_embeds(self,
|
||||
parsed: ParsedEmbedsPrompt) -> EmbedsInputs:
|
||||
if envs.VLLM_USE_V1:
|
||||
raise ValueError("prompt_embeds is only available in V0.")
|
||||
|
||||
prompt_embeds_content = parsed["content"]
|
||||
|
||||
prompt_embeds = prompt_embeds_content["prompt_embeds"]
|
||||
|
||||
# prompt_embeds must be (seq_len, hidden_size), but if the user
|
||||
# passes in a batch of size 1, i.e. (1, seq_len, hidden_size),
|
||||
# we can unambiguously process the intent by squeezing the batch
|
||||
# dimension.
|
||||
if prompt_embeds.ndim == 3 and prompt_embeds.shape[0] == 1:
|
||||
prompt_embeds = prompt_embeds.squeeze(dim=0)
|
||||
|
||||
if prompt_embeds.ndim != 2:
|
||||
raise ValueError(
|
||||
"prompt_embeds must be of shape (seq_len, hidden_size).")
|
||||
|
||||
return embeds_inputs(prompt_embeds=prompt_embeds)
|
||||
|
||||
assert_never(parsed)
|
||||
|
||||
def _build_enc_dec_llm_inputs(
|
||||
self,
|
||||
encoder_inputs: SingletonInputs,
|
||||
decoder_inputs: Optional[SingletonInputs],
|
||||
encoder_inputs: Union[TokenInputs, MultiModalInputs],
|
||||
decoder_inputs: Optional[Union[TokenInputs, MultiModalInputs]],
|
||||
) -> EncoderDecoderInputs:
|
||||
if (encoder_inputs["type"] == "token"
|
||||
or encoder_inputs["type"] == "multimodal"):
|
||||
@@ -410,6 +446,9 @@ class InputPreprocessor:
|
||||
else:
|
||||
assert_never(encoder_inputs) # type: ignore[arg-type]
|
||||
|
||||
# Mypy does not correctly infer that EmbedsInputs is impossible
|
||||
assert "prompt_token_ids" in encoder_inputs
|
||||
|
||||
if decoder_inputs is None:
|
||||
if self.model_config.hf_config.model_type == "whisper":
|
||||
# For Whisper models, the text prompt should go to the decoder.
|
||||
@@ -441,7 +480,8 @@ class InputPreprocessor:
|
||||
def _separate_enc_dec_inputs_from_mm_processor_outputs(
|
||||
self,
|
||||
inputs: SingletonInputs,
|
||||
decoder_inputs_to_override: Optional[SingletonInputs] = None,
|
||||
decoder_inputs_to_override: Optional[Union[TokenInputs,
|
||||
MultiModalInputs]] = None,
|
||||
) -> tuple[SingletonInputs, SingletonInputs]:
|
||||
"""
|
||||
For encoder/decoder models only:
|
||||
@@ -540,6 +580,8 @@ class InputPreprocessor:
|
||||
# For multimodal model, override decoder prompt from processor
|
||||
# with explicit decoder prompt.
|
||||
if self.model_config.is_multimodal_model:
|
||||
assert decoder_inputs is None or not is_embeds_inputs(
|
||||
decoder_inputs)
|
||||
encoder_inputs, decoder_inputs = (
|
||||
self._separate_enc_dec_inputs_from_mm_processor_outputs(
|
||||
encoder_inputs, decoder_inputs))
|
||||
@@ -555,9 +597,12 @@ class InputPreprocessor:
|
||||
inputs))
|
||||
else:
|
||||
encoder_inputs = inputs
|
||||
|
||||
decoder_inputs = None
|
||||
|
||||
# Mypy does not do type inference well with TypedDicts with Literal
|
||||
# values.
|
||||
assert not is_embeds_inputs(encoder_inputs)
|
||||
assert decoder_inputs is None or not is_embeds_inputs(decoder_inputs)
|
||||
return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
|
||||
|
||||
async def _process_encoder_decoder_prompt_async(
|
||||
@@ -590,6 +635,8 @@ class InputPreprocessor:
|
||||
# For multimodal model, override decoder prompt from processor
|
||||
# with explicit decoder prompt.
|
||||
if self.model_config.is_multimodal_model:
|
||||
assert decoder_inputs is None or not is_embeds_inputs(
|
||||
decoder_inputs)
|
||||
encoder_inputs, decoder_inputs = (
|
||||
self._separate_enc_dec_inputs_from_mm_processor_outputs(
|
||||
encoder_inputs, decoder_inputs))
|
||||
@@ -605,9 +652,12 @@ class InputPreprocessor:
|
||||
inputs))
|
||||
else:
|
||||
encoder_inputs = inputs
|
||||
|
||||
decoder_inputs = None
|
||||
|
||||
# Mypy does not do type inference well with TypedDicts with Literal
|
||||
# values.
|
||||
assert not is_embeds_inputs(encoder_inputs)
|
||||
assert decoder_inputs is None or not is_embeds_inputs(decoder_inputs)
|
||||
return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
|
||||
|
||||
def _build_decoder_only_llm_inputs(
|
||||
@@ -617,10 +667,15 @@ class InputPreprocessor:
|
||||
) -> DecoderOnlyInputs:
|
||||
if (prompt_inputs["type"] == "token"
|
||||
or prompt_inputs["type"] == "multimodal"):
|
||||
# Mypy does not do type inference well with typedicts and Literal
|
||||
# values
|
||||
assert not is_embeds_inputs(prompt_inputs)
|
||||
prompt_inputs["prompt_token_ids"] = self._apply_prompt_adapter(
|
||||
prompt_inputs["prompt_token_ids"],
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
)
|
||||
elif (prompt_inputs["type"] == "embeds"):
|
||||
pass
|
||||
else:
|
||||
assert_never(prompt_inputs) # type: ignore[arg-type]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user