[Core] Make encoder-decoder inputs a nested structure to be more composable (#9604)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Cyrus Leung
Date: 2024-11-05 10:07:31 +08:00
Committed by: GitHub
Parent: 04bbf38e05
Commit: bbc3619dc8
14 changed files with 369 additions and 346 deletions
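
For context, this change replaces the previous flat encoder-decoder input dict with a structure nested by role. A minimal sketch of the two layouts (the key names are taken from the diff below; the token ids are placeholders):

    # Before: one flat dict mixing decoder and encoder fields.
    flat_inputs = {
        "prompt_token_ids": [0, 11, 22],          # decoder prompt
        "encoder_prompt_token_ids": [0, 33, 44],  # encoder prompt
    }

    # After: the nested form keeps one token-prompt dict per role, so the
    # same per-role structure is reused for both halves. Decoder-only
    # inputs keep the flat single-prompt form (see the
    # is_encoder_decoder_inputs branches below).
    nested_inputs = {
        "decoder": {"prompt_token_ids": [0, 11, 22]},
        "encoder": {"prompt_token_ids": [0, 33, 44]},
    }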

vllm/engine/llm_engine.py

@@ -10,7 +10,7 @@ from typing import Sequence as GenericSequence
 from typing import Set, Type, Union, cast, overload
 
 import torch
-from typing_extensions import TypeIs, TypeVar
+from typing_extensions import TypeVar
 
 import vllm.envs as envs
 from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig,
@@ -29,9 +29,9 @@ from vllm.entrypoints.openai.logits_processors import (
 from vllm.executor.executor_base import ExecutorBase
 from vllm.executor.gpu_executor import GPUExecutor
 from vllm.executor.ray_utils import initialize_ray_cluster
-from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs,
-                         EncoderDecoderInputs, InputRegistry, PromptType,
-                         TokensPrompt)
+from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
+                         PromptType)
+from vllm.inputs.parse import is_encoder_decoder_inputs, is_token_prompt
 from vllm.inputs.preprocess import InputPreprocessor
 from vllm.logger import init_logger
 from vllm.logits_process import get_bad_words_logits_processors
@@ -638,7 +638,7 @@ class LLMEngine:
     def _add_processed_request(
         self,
         request_id: str,
-        processed_inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs],
+        processed_inputs: ProcessorInputs,
         params: Union[SamplingParams, PoolingParams],
         arrival_time: float,
         lora_request: Optional[LoRARequest],
@@ -669,18 +669,19 @@
         seq_id = next(self.seq_counter)
         eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
 
-        seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id,
+        if is_encoder_decoder_inputs(processed_inputs):
+            decoder_inputs = processed_inputs["decoder"]
+            encoder_inputs = processed_inputs["encoder"]
+        else:
+            decoder_inputs = processed_inputs
+            encoder_inputs = None
+
+        seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id,
                        lora_request, prompt_adapter_request)
 
-        encoder_seq = None
-        if 'encoder_prompt_token_ids' in processed_inputs:
-            encoder_seq = Sequence(seq_id,
-                                   processed_inputs,
-                                   block_size,
-                                   eos_token_id,
-                                   lora_request,
-                                   prompt_adapter_request,
-                                   from_decoder_prompt=False)
+        encoder_seq = (None if encoder_inputs is None else Sequence(
+            seq_id, encoder_inputs, block_size, eos_token_id, lora_request,
+            prompt_adapter_request))
 
         # Create a SequenceGroup based on SamplingParams or PoolingParams
         if isinstance(params, SamplingParams):
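
The is_encoder_decoder_inputs helper used above comes from vllm.inputs.parse, and its implementation is not part of this diff. Given the branches it feeds, a plausible sketch is a type guard keyed on the nested "encoder" entry (illustrative only; the TypedDict below is a simplified stand-in for the real EncoderDecoderInputs):

    from typing import Any

    from typing_extensions import TypedDict, TypeIs


    class EncoderDecoderInputs(TypedDict):
        # Simplified: each entry is a per-role prompt dict with at least
        # "prompt_token_ids".
        encoder: dict
        decoder: dict


    def is_encoder_decoder_inputs(inputs: Any) -> TypeIs[EncoderDecoderInputs]:
        # With the nested layout, encoder-decoder inputs are the ones that
        # carry the per-role "encoder" entry; decoder-only inputs do not.
        return isinstance(inputs, dict) and "encoder" in inputs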
@@ -874,7 +875,7 @@
         # This needs to happen before multimodal input pre-processing, which
         # may add dummy <image> tokens that aren't part of the tokenizer's
         # vocabulary.
-        if self._is_token_prompt(prompt):
+        if is_token_prompt(prompt):
             prompt_ids = prompt["prompt_token_ids"]
             if len(prompt_ids) == 0:
                 # Empty prompt check is handled later
@@ -884,10 +885,6 @@
                 raise ValueError(
                     "Token id {} is out of vocabulary".format(max_input_id))
 
-    @staticmethod
-    def _is_token_prompt(prompt: PromptType) -> TypeIs[TokensPrompt]:
-        return isinstance(prompt, dict) and "prompt_token_ids" in prompt
-
     def _create_sequence_group_with_sampling(
         self,
         request_id: str,
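
The _is_token_prompt static method removed above is superseded by the is_token_prompt helper imported from vllm.inputs.parse. Presumably it performs the same narrowing; a minimal sketch of such a TypeIs guard, with TokensPrompt reduced to its one relevant field:

    from typing import Any, List

    from typing_extensions import TypedDict, TypeIs


    class TokensPrompt(TypedDict):
        # Simplified: the real TokensPrompt in vllm.inputs carries more
        # optional fields (e.g. multi-modal data).
        prompt_token_ids: List[int]


    def is_token_prompt(prompt: Any) -> TypeIs[TokensPrompt]:
        # Same check as the removed static method: a dict that already
        # carries pre-tokenized ids is treated as a TokensPrompt.
        return isinstance(prompt, dict) and "prompt_token_ids" in prompt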
@@ -1978,17 +1975,17 @@ class LLMEngine:
     def is_encoder_decoder_model(self):
         return self.input_preprocessor.is_encoder_decoder_model()
 
-    def _validate_model_inputs(self, inputs: Union[DecoderOnlyInputs,
-                                                    EncoderDecoderInputs],
+    def _validate_model_inputs(self, inputs: ProcessorInputs,
                                lora_request: Optional[LoRARequest]):
-        if self.model_config.is_multimodal_model:
+        if is_encoder_decoder_inputs(inputs):
             # For encoder-decoder multimodal models, the max_prompt_len
             # restricts the decoder prompt length
-            prompt_ids = inputs.get("prompt_token_ids")
-        elif self.is_encoder_decoder_model():
-            prompt_ids = inputs.get("encoder_prompt_token_ids")
+            prompt_inputs = inputs["decoder" if self.model_config.
+                                   is_multimodal_model else "encoder"]
         else:
-            prompt_ids = inputs.get("prompt_token_ids")
+            prompt_inputs = inputs
+
+        prompt_ids = prompt_inputs.get("prompt_token_ids")
 
         if prompt_ids is None or len(prompt_ids) == 0:
             raise ValueError("Prompt cannot be empty")