[Misc] Replace is_encoder_decoder_inputs with split_enc_dec_inputs (#15620)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-03-28 01:36:32 +08:00
committed by GitHub
parent 07bf813fb5
commit 247181536f
8 changed files with 49 additions and 54 deletions

View File

@@ -7,7 +7,7 @@ from typing import Optional, Union
from vllm.config import VllmConfig
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
PromptType, SingletonInputsAdapter)
from vllm.inputs.parse import is_encoder_decoder_inputs
from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor
from vllm.lora.request import LoRARequest
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
@@ -209,14 +209,8 @@ class Processor:
self._validate_model_inputs(processed_inputs, lora_request)
if is_encoder_decoder_inputs(processed_inputs):
decoder_inputs = SingletonInputsAdapter(
processed_inputs["decoder"])
encoder_inputs = SingletonInputsAdapter(
processed_inputs["encoder"])
else:
decoder_inputs = SingletonInputsAdapter(processed_inputs)
encoder_inputs = None
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
decoder_inputs = SingletonInputsAdapter(decoder_inputs)
# TODO: Impl encoder-decoder
if encoder_inputs is not None:
@@ -301,15 +295,16 @@ class Processor:
def _validate_model_inputs(self,
inputs: ProcessorInputs,
lora_request: Optional[LoRARequest] = None):
if is_encoder_decoder_inputs(inputs):
# For encoder-decoder multimodal models, the max_prompt_len
# restricts the decoder prompt length
prompt_inputs = inputs["decoder" if self.model_config.
is_multimodal_model else "encoder"]
else:
prompt_inputs = inputs
encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)
prompt_ids = SingletonInputsAdapter(prompt_inputs).prompt_token_ids
# For encoder-decoder multimodal models, the max_prompt_len
# restricts the decoder prompt length
if self.model_config.is_multimodal_model:
prompt_inputs = decoder_inputs
else:
prompt_inputs = encoder_inputs or decoder_inputs
prompt_ids = prompt_inputs["prompt_token_ids"]
if prompt_ids is None or len(prompt_ids) == 0:
raise ValueError("Prompt cannot be empty")