[Misc] Replace is_encoder_decoder_inputs with split_enc_dec_inputs (#15620)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-03-28 01:36:32 +08:00
committed by GitHub
parent 07bf813fb5
commit 247181536f
8 changed files with 49 additions and 54 deletions

View File

@@ -30,8 +30,8 @@ from vllm.entrypoints.openai.logits_processors import (
get_logits_processors as get_openai_logits_processors)
from vllm.executor.executor_base import ExecutorBase
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
PromptType, SingletonInputsAdapter)
from vllm.inputs.parse import is_encoder_decoder_inputs, is_token_prompt
PromptType)
from vllm.inputs.parse import is_token_prompt, split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.logits_process import get_bad_words_logits_processors
@@ -609,12 +609,7 @@ class LLMEngine:
seq_id = next(self.seq_counter)
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
if is_encoder_decoder_inputs(processed_inputs):
decoder_inputs = processed_inputs["decoder"]
encoder_inputs = processed_inputs["encoder"]
else:
decoder_inputs = processed_inputs
encoder_inputs = None
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id,
lora_request, prompt_adapter_request)
@@ -2031,15 +2026,16 @@ class LLMEngine:
def _validate_model_inputs(self, inputs: ProcessorInputs,
lora_request: Optional[LoRARequest]):
if is_encoder_decoder_inputs(inputs):
# For encoder-decoder multimodal models, the max_prompt_len
# restricts the decoder prompt length
prompt_inputs = inputs["decoder" if self.model_config.
is_multimodal_model else "encoder"]
else:
prompt_inputs = inputs
encoder_inputs, decoder_inputs = split_enc_dec_inputs(inputs)
prompt_ids = SingletonInputsAdapter(prompt_inputs).prompt_token_ids
# For encoder-decoder multimodal models, the max_prompt_len
# restricts the decoder prompt length
if self.model_config.is_multimodal_model:
prompt_inputs = decoder_inputs
else:
prompt_inputs = encoder_inputs or decoder_inputs
prompt_ids = prompt_inputs["prompt_token_ids"]
if prompt_ids is None or len(prompt_ids) == 0:
raise ValueError("Prompt cannot be empty")