[CORE] Prompt Embeddings Support for v1 Engine (#24278)
Signed-off-by: Andrew Sansom <andrew@protopia.ai>
Signed-off-by: Andrew Sansom <qthequartermasterman@gmail.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
@@ -19,6 +19,7 @@ from vllm.multimodal.utils import argsort_mm_positions
 from vllm.pooling_params import PoolingParams
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.utils import length_from_prompt_token_ids_or_embeds
 from vllm.v1.engine import EngineCoreRequest
 from vllm.v1.structured_output.backend_guidance import (
     validate_guidance_grammar)
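Note: the body of the newly imported helper `length_from_prompt_token_ids_or_embeds` is not part of this diff. A minimal sketch of the behavior the call sites below rely on, assuming the obvious implementation (not the actual `vllm.utils` source):

    from typing import Optional

    import torch

    def length_from_prompt_token_ids_or_embeds(
        prompt_token_ids: Optional[list[int]],
        prompt_embeds: Optional[torch.Tensor],
    ) -> int:
        # At least one of the two inputs is expected to be set; the sequence
        # length is the number of token ids or the number of embedding rows.
        if prompt_token_ids is not None:
            return len(prompt_token_ids)
        assert prompt_embeds is not None
        return prompt_embeds.shape[0]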
@@ -390,6 +391,16 @@ class Processor:
         self._validate_model_inputs(processed_inputs)

         encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
+        # Mypy does not always properly infer the types of some elements of
+        # discriminated unions of TypedDicts, because of how it handles
+        # inheritance of TypedDict. If we explicitly extract the items we want
+        # we can avoid type errors from using `dict.get` later in the method.
+        prompt_str: Optional[str] = None if decoder_inputs[
+            "type"] == "embeds" else decoder_inputs.get("prompt")
+        prompt_token_ids = decoder_inputs[
+            "prompt_token_ids"] if decoder_inputs["type"] != "embeds" else None
+        prompt_embeds = decoder_inputs["prompt_embeds"] if decoder_inputs[
+            "type"] == "embeds" else None

         sampling_params = None
         pooling_params = None
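For illustration, the two shapes of `decoder_inputs` that the extraction above has to handle (simplified example dicts, not the full TypedDicts):

    import torch

    # Token-based request: prompt text and token ids are present, no embeddings.
    token_inputs = {"type": "token", "prompt": "Hello", "prompt_token_ids": [9906]}
    # -> prompt_str="Hello", prompt_token_ids=[9906], prompt_embeds=None

    # Embedding-based request: only a (seq_len, hidden_size) tensor is present.
    embeds_inputs = {"type": "embeds", "prompt_embeds": torch.randn(16, 4096)}
    # -> prompt_str=None, prompt_token_ids=None, prompt_embeds=<16x4096 tensor>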
@@ -398,9 +409,10 @@ class Processor:
             sampling_params = params.clone()
             # If unset max tokens, then generate up to the max_model_len.
             if sampling_params.max_tokens is None:
-                sampling_params.max_tokens = (
-                    self.model_config.max_model_len -
-                    len(decoder_inputs["prompt_token_ids"]))
+                seq_len = length_from_prompt_token_ids_or_embeds(
+                    prompt_token_ids, prompt_embeds)
+                sampling_params.max_tokens = \
+                    self.model_config.max_model_len - seq_len
             sampling_params.update_from_generation_config(
                 self.generation_config_fields, eos_token_id)
             if self.tokenizer is not None:
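The default max_tokens computation is unchanged for token inputs and now also works when only embeddings are supplied. A worked example with illustrative numbers:

    import torch

    max_model_len = 4096
    prompt_embeds = torch.randn(1024, 4096)  # 1024 positions worth of embeddings
    seq_len = prompt_embeds.shape[0]         # what the helper returns for embeds
    assert max_model_len - seq_len == 3072   # default max_tokens for this request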
@@ -430,9 +442,10 @@ class Processor:
                         identifier=decoder_mm_hashes[modality][idx],
                         mm_position=decoder_mm_positions[modality][idx]))

-        return decoder_inputs.get("prompt"), EngineCoreRequest(
+        return prompt_str, EngineCoreRequest(
             request_id=request_id,
-            prompt_token_ids=decoder_inputs["prompt_token_ids"],
+            prompt_token_ids=prompt_token_ids,
+            prompt_embeds=prompt_embeds,
             mm_features=mm_features,
             sampling_params=sampling_params,
             pooling_params=pooling_params,
@@ -461,10 +474,17 @@ class Processor:
     ):
         model_config = self.model_config

-        prompt_ids = prompt_inputs["prompt_token_ids"]
+        prompt_ids = None if prompt_inputs[
+            "type"] == "embeds" else prompt_inputs["prompt_token_ids"]
+        prompt_embeds = prompt_inputs["prompt_embeds"] if prompt_inputs[
+            "type"] == "embeds" else None
+        prompt_len = length_from_prompt_token_ids_or_embeds(
+            prompt_ids, prompt_embeds)
         if not prompt_ids:
             if prompt_type == "encoder" and model_config.is_multimodal_model:
                 pass  # Mllama may have empty encoder inputs for text-only data
+            elif prompt_inputs["type"] == "embeds":
+                pass  # Prompt embeds should not have prompt_ids.
             else:
                 raise ValueError(f"The {prompt_type} prompt cannot be empty")

@@ -472,7 +492,7 @@ class Processor:
             tokenizer = None
         else:
             tokenizer = self.tokenizer
-            max_input_id = max(prompt_ids, default=0)
+            max_input_id = max(prompt_ids or [], default=0)

             # NOTE: tokenizer.max_token_id is the tokenizer's vocab size while
             # self.model_config.get_vocab_size() is the model's vocab size.
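For embeds-type inputs `prompt_ids` is None, so the new `or []` guard keeps the vocabulary-range check from raising a TypeError and lets it pass trivially:

    prompt_ids = None                                # embeds inputs carry no token ids
    max_input_id = max(prompt_ids or [], default=0)  # max([], default=0) == 0
    assert max_input_id == 0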
@@ -490,7 +510,7 @@ class Processor:
                     f"Token id {max_input_id} is out of vocabulary")

         max_prompt_len = self.model_config.max_model_len
-        if len(prompt_ids) > max_prompt_len:
+        if prompt_len > max_prompt_len:
             if prompt_type == "encoder" and model_config.is_multimodal_model:
                 mm_registry = self.input_preprocessor.mm_registry
                 mm_processor = mm_registry.create_processor(
@@ -514,7 +534,7 @@ class Processor:
                     "number of text tokens.")

             raise ValueError(
-                f"The {prompt_type} prompt (length {len(prompt_ids)}) is "
+                f"The {prompt_type} prompt (length {prompt_len}) is "
                 f"longer than the maximum model length of {max_prompt_len}. "
                 f"{suggestion}")
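For context, a rough end-to-end usage sketch of the feature this change plumbs into the v1 engine. The engine flag and prompt-dict key follow the existing prompt-embeds surface of the offline LLM entry point, and the model name, shapes, and dtype here are only placeholders:

    import torch
    from vllm import LLM, SamplingParams

    llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", enable_prompt_embeds=True)

    # Per-position embeddings of shape (seq_len, hidden_size), e.g. taken from
    # the model's input embedding layer or produced by a custom encoder.
    prompt_embeds = torch.randn(32, 2048, dtype=torch.bfloat16)

    outputs = llm.generate({"prompt_embeds": prompt_embeds},
                           SamplingParams(max_tokens=64))
    print(outputs[0].outputs[0].text)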