[CORE] Prompt Embeddings Support for v1 Engine (#24278)

Signed-off-by: Andrew Sansom <andrew@protopia.ai>
Signed-off-by: Andrew Sansom <qthequartermasterman@gmail.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Authored by Andrew Sansom on 2025-09-18 19:03:09 -05:00; committed by GitHub
parent 9fac6aa30b
commit 9a4600e4dc
20 changed files with 305 additions and 76 deletions


@@ -19,6 +19,7 @@ from vllm.multimodal.utils import argsort_mm_positions
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.utils import length_from_prompt_token_ids_or_embeds
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.structured_output.backend_guidance import (
     validate_guidance_grammar)
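
The new import is the only change in this hunk. length_from_prompt_token_ids_or_embeds is used throughout the rest of the diff to compute a prompt length regardless of whether the request carries token ids or an embeddings tensor. A minimal sketch of what such a helper plausibly does (the real implementation lives in vllm.utils; this version is an assumption, not a copy):

    from typing import Optional

    import torch


    def length_from_prompt_token_ids_or_embeds(
            prompt_token_ids: Optional[list[int]],
            prompt_embeds: Optional[torch.Tensor]) -> int:
        # Hypothetical sketch: prefer token ids; otherwise fall back to the
        # number of rows in the [seq_len, hidden_size] embeddings tensor.
        if prompt_token_ids is not None:
            return len(prompt_token_ids)
        if prompt_embeds is not None:
            return prompt_embeds.shape[0]
        raise ValueError("Either prompt_token_ids or prompt_embeds must be set")
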
@@ -390,6 +391,16 @@ class Processor:
         self._validate_model_inputs(processed_inputs)
         encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
+        # Mypy does not always properly infer the types of some elements of
+        # discriminated unions of TypedDicts, because of how it handles
+        # inheritance of TypedDict. If we explicitly extract the items we want
+        # we can avoid type errors from using `dict.get` later in the method.
+        prompt_str: Optional[str] = None if decoder_inputs[
+            "type"] == "embeds" else decoder_inputs.get("prompt")
+        prompt_token_ids = decoder_inputs[
+            "prompt_token_ids"] if decoder_inputs["type"] != "embeds" else None
+        prompt_embeds = decoder_inputs["prompt_embeds"] if decoder_inputs[
+            "type"] == "embeds" else None
         sampling_params = None
         pooling_params = None
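
The code comment above gives the motivation: decoder_inputs is a tagged union of TypedDicts, and mypy only narrows it once the "type" tag has been checked explicitly. A toy illustration of the same pattern (simplified stand-ins, not vLLM's real input types):

    from typing import Literal, Optional, TypedDict, Union

    import torch


    class ToyTokensPrompt(TypedDict):
        type: Literal["tokens"]
        prompt_token_ids: list[int]


    class ToyEmbedsPrompt(TypedDict):
        type: Literal["embeds"]
        prompt_embeds: torch.Tensor


    ToyDecoderInputs = Union[ToyTokensPrompt, ToyEmbedsPrompt]


    def extract(inputs: ToyDecoderInputs
                ) -> tuple[Optional[list[int]], Optional[torch.Tensor]]:
        # Branching on the tag gives mypy a single concrete TypedDict in each
        # arm, so the key accesses below type-check without dict.get tricks.
        if inputs["type"] == "embeds":
            return None, inputs["prompt_embeds"]
        return inputs["prompt_token_ids"], None
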
@@ -398,9 +409,10 @@ class Processor:
             sampling_params = params.clone()
             # If unset max tokens, then generate up to the max_model_len.
             if sampling_params.max_tokens is None:
-                sampling_params.max_tokens = (
-                    self.model_config.max_model_len -
-                    len(decoder_inputs["prompt_token_ids"]))
+                seq_len = length_from_prompt_token_ids_or_embeds(
+                    prompt_token_ids, prompt_embeds)
+                sampling_params.max_tokens = \
+                    self.model_config.max_model_len - seq_len
             sampling_params.update_from_generation_config(
                 self.generation_config_fields, eos_token_id)
             if self.tokenizer is not None:
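
With embeddings inputs there are no token ids to call len() on, so the default max_tokens is now derived from whichever representation is present. A worked example reusing the helper sketched earlier (numbers are illustrative):

    import torch

    max_model_len = 4096
    prompt_embeds = torch.zeros(1000, 2048)    # [seq_len, hidden_size]
    seq_len = length_from_prompt_token_ids_or_embeds(None, prompt_embeds)
    max_tokens = max_model_len - seq_len       # 4096 - 1000 = 3096
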
@@ -430,9 +442,10 @@ class Processor:
                             identifier=decoder_mm_hashes[modality][idx],
                             mm_position=decoder_mm_positions[modality][idx]))
-        return decoder_inputs.get("prompt"), EngineCoreRequest(
+        return prompt_str, EngineCoreRequest(
             request_id=request_id,
-            prompt_token_ids=decoder_inputs["prompt_token_ids"],
+            prompt_token_ids=prompt_token_ids,
+            prompt_embeds=prompt_embeds,
             mm_features=mm_features,
             sampling_params=sampling_params,
             pooling_params=pooling_params,
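
This is where prompt_embeds first enters the EngineCoreRequest handed to the engine core. For context, a hedged sketch of the user-facing call that produces such a request, based on the pre-existing prompt-embeddings API (the model name, flag spelling, and dtype are assumptions; check the docs for your version):

    import torch
    from vllm import LLM, SamplingParams

    # enable_prompt_embeds and the {"prompt_embeds": ...} dict form are drawn
    # from the existing prompt-embeddings support; exact flags may differ.
    llm = LLM(model="meta-llama/Llama-3.2-1B", enable_prompt_embeds=True)
    prompt_embeds = torch.randn(16, 2048, dtype=torch.bfloat16)  # [seq_len, hidden]
    outputs = llm.generate({"prompt_embeds": prompt_embeds},
                           SamplingParams(max_tokens=32))
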
@@ -461,10 +474,17 @@ class Processor:
     ):
         model_config = self.model_config
-        prompt_ids = prompt_inputs["prompt_token_ids"]
+        prompt_ids = None if prompt_inputs[
+            "type"] == "embeds" else prompt_inputs["prompt_token_ids"]
+        prompt_embeds = prompt_inputs["prompt_embeds"] if prompt_inputs[
+            "type"] == "embeds" else None
+        prompt_len = length_from_prompt_token_ids_or_embeds(
+            prompt_ids, prompt_embeds)
         if not prompt_ids:
             if prompt_type == "encoder" and model_config.is_multimodal_model:
                 pass  # Mllama may have empty encoder inputs for text-only data
+            elif prompt_inputs["type"] == "embeds":
+                pass  # Prompt embeds should not have prompt_ids.
             else:
                 raise ValueError(f"The {prompt_type} prompt cannot be empty")
@@ -472,7 +492,7 @@ class Processor:
             tokenizer = None
         else:
             tokenizer = self.tokenizer
-            max_input_id = max(prompt_ids, default=0)
+            max_input_id = max(prompt_ids or [], default=0)
             # NOTE: tokenizer.max_token_id is the tokenizer's vocab size while
             # self.model_config.get_vocab_size() is the model's vocab size.
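
prompt_ids can now be None for embeds requests, and max() rejects None even when a default is given, so the "or []" guard is what keeps the vocabulary check working:

    max([], default=0)          # 0 -> the default applies to an empty list
    max(None or [], default=0)  # 0 -> None is coerced to [] before max() sees it
    # max(None, default=0)      # TypeError: 'NoneType' object is not iterable
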
@@ -490,7 +510,7 @@ class Processor:
f"Token id {max_input_id} is out of vocabulary")
max_prompt_len = self.model_config.max_model_len
if len(prompt_ids) > max_prompt_len:
if prompt_len > max_prompt_len:
if prompt_type == "encoder" and model_config.is_multimodal_model:
mm_registry = self.input_preprocessor.mm_registry
mm_processor = mm_registry.create_processor(
@@ -514,7 +534,7 @@ class Processor:
"number of text tokens.")
raise ValueError(
f"The {prompt_type} prompt (length {len(prompt_ids)}) is "
f"The {prompt_type} prompt (length {prompt_len}) is "
f"longer than the maximum model length of {max_prompt_len}. "
f"{suggestion}")