[Core] Make encoder-decoder inputs a nested structure to be more composable (#9604)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -10,7 +10,7 @@ from typing import Sequence as GenericSequence
|
||||
from typing import Set, Type, Union, cast, overload
|
||||
|
||||
import torch
|
||||
from typing_extensions import TypeIs, TypeVar
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
|
||||
@@ -29,9 +29,9 @@ from vllm.entrypoints.openai.logits_processors import (
|
||||
from vllm.executor.executor_base import ExecutorBase
|
||||
from vllm.executor.gpu_executor import GPUExecutor
|
||||
from vllm.executor.ray_utils import initialize_ray_cluster
|
||||
from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs,
|
||||
EncoderDecoderInputs, InputRegistry, PromptType,
|
||||
TokensPrompt)
|
||||
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
|
||||
PromptType)
|
||||
from vllm.inputs.parse import is_encoder_decoder_inputs, is_token_prompt
|
||||
from vllm.inputs.preprocess import InputPreprocessor
|
||||
from vllm.logger import init_logger
|
||||
from vllm.logits_process import get_bad_words_logits_processors
|
||||
@@ -638,7 +638,7 @@ class LLMEngine:
|
||||
def _add_processed_request(
|
||||
self,
|
||||
request_id: str,
|
||||
processed_inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs],
|
||||
processed_inputs: ProcessorInputs,
|
||||
params: Union[SamplingParams, PoolingParams],
|
||||
arrival_time: float,
|
||||
lora_request: Optional[LoRARequest],
|
||||
@@ -669,18 +669,19 @@ class LLMEngine:
|
||||
seq_id = next(self.seq_counter)
|
||||
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
|
||||
|
||||
seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id,
|
||||
if is_encoder_decoder_inputs(processed_inputs):
|
||||
decoder_inputs = processed_inputs["decoder"]
|
||||
encoder_inputs = processed_inputs["encoder"]
|
||||
else:
|
||||
decoder_inputs = processed_inputs
|
||||
encoder_inputs = None
|
||||
|
||||
seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id,
|
||||
lora_request, prompt_adapter_request)
|
||||
|
||||
encoder_seq = None
|
||||
if 'encoder_prompt_token_ids' in processed_inputs:
|
||||
encoder_seq = Sequence(seq_id,
|
||||
processed_inputs,
|
||||
block_size,
|
||||
eos_token_id,
|
||||
lora_request,
|
||||
prompt_adapter_request,
|
||||
from_decoder_prompt=False)
|
||||
encoder_seq = (None if encoder_inputs is None else Sequence(
|
||||
seq_id, encoder_inputs, block_size, eos_token_id, lora_request,
|
||||
prompt_adapter_request))
|
||||
|
||||
# Create a SequenceGroup based on SamplingParams or PoolingParams
|
||||
if isinstance(params, SamplingParams):
|
||||
@@ -874,7 +875,7 @@ class LLMEngine:
|
||||
# This needs to happen before multimodal input pre-processing, which
|
||||
# may add dummy <image> tokens that aren't part of the tokenizer's
|
||||
# vocabulary.
|
||||
if self._is_token_prompt(prompt):
|
||||
if is_token_prompt(prompt):
|
||||
prompt_ids = prompt["prompt_token_ids"]
|
||||
if len(prompt_ids) == 0:
|
||||
# Empty prompt check is handled later
|
||||
@@ -884,10 +885,6 @@ class LLMEngine:
|
||||
raise ValueError(
|
||||
"Token id {} is out of vocabulary".format(max_input_id))
|
||||
|
||||
@staticmethod
|
||||
def _is_token_prompt(prompt: PromptType) -> TypeIs[TokensPrompt]:
|
||||
return isinstance(prompt, dict) and "prompt_token_ids" in prompt
|
||||
|
||||
def _create_sequence_group_with_sampling(
|
||||
self,
|
||||
request_id: str,
|
||||
@@ -1978,17 +1975,17 @@ class LLMEngine:
|
||||
def is_encoder_decoder_model(self):
|
||||
return self.input_preprocessor.is_encoder_decoder_model()
|
||||
|
||||
def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs,
|
||||
EncoderDecoderInputs],
|
||||
def _validate_model_inputs(self, inputs: ProcessorInputs,
|
||||
lora_request: Optional[LoRARequest]):
|
||||
if self.model_config.is_multimodal_model:
|
||||
if is_encoder_decoder_inputs(inputs):
|
||||
# For encoder-decoder multimodal models, the max_prompt_len
|
||||
# restricts the decoder prompt length
|
||||
prompt_ids = inputs.get("prompt_token_ids")
|
||||
elif self.is_encoder_decoder_model():
|
||||
prompt_ids = inputs.get("encoder_prompt_token_ids")
|
||||
prompt_inputs = inputs["decoder" if self.model_config.
|
||||
is_multimodal_model else "encoder"]
|
||||
else:
|
||||
prompt_ids = inputs.get("prompt_token_ids")
|
||||
prompt_inputs = inputs
|
||||
|
||||
prompt_ids = prompt_inputs.get("prompt_token_ids")
|
||||
|
||||
if prompt_ids is None or len(prompt_ids) == 0:
|
||||
raise ValueError("Prompt cannot be empty")
|
||||
|
||||
Reference in New Issue
Block a user