[Misc] Replace is_encoder_decoder_inputs with split_enc_dec_inputs (#15620)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-03-28 01:36:32 +08:00
committed by GitHub
parent 07bf813fb5
commit 247181536f
8 changed files with 49 additions and 54 deletions

View File

@@ -1,15 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
from collections.abc import Sequence
from typing import Literal, TypedDict, Union, cast, overload
from typing import Literal, Optional, TypedDict, Union, cast, overload
from typing_extensions import TypeIs
from vllm.utils import is_list_of
from .data import (EncoderDecoderInputs, ExplicitEncoderDecoderPrompt,
ProcessorInputs, PromptType, SingletonPrompt, TextPrompt,
TokensPrompt)
from .data import (ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType,
SingletonInputs, SingletonPrompt, TextPrompt, TokensPrompt)
class ParsedText(TypedDict):
@@ -110,6 +108,14 @@ def is_explicit_encoder_decoder_prompt(
return isinstance(prompt, dict) and "encoder_prompt" in prompt
def is_encoder_decoder_inputs(
inputs: ProcessorInputs) -> TypeIs[EncoderDecoderInputs]:
return "encoder" in inputs and "decoder" in inputs
def split_enc_dec_inputs(
inputs: ProcessorInputs,
) -> tuple[Optional[SingletonInputs], SingletonInputs]:
if "encoder" in inputs and "decoder" in inputs:
# NOTE: This passes pyright but not mypy
return (
inputs["encoder"], # type: ignore[typeddict-item]
inputs["decoder"], # type: ignore[typeddict-item]
)
return None, inputs

View File

@@ -19,7 +19,7 @@ from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides,
resolve_mm_processor_kwargs)
from .data import ProcessorInputs, SingletonInputs
from .parse import is_encoder_decoder_inputs
from .parse import split_enc_dec_inputs
if TYPE_CHECKING:
from vllm.config import ModelConfig
@@ -462,13 +462,11 @@ class InputRegistry:
**mm_processor_kwargs,
)
if is_encoder_decoder_inputs(processed_inputs):
self._ensure_mm_kwargs(processed_inputs["encoder"],
mm_processor_kwargs)
self._ensure_mm_kwargs(processed_inputs["decoder"],
mm_processor_kwargs)
else:
self._ensure_mm_kwargs(processed_inputs, mm_processor_kwargs)
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
if encoder_inputs is not None:
self._ensure_mm_kwargs(encoder_inputs, mm_processor_kwargs)
if decoder_inputs is not None:
self._ensure_mm_kwargs(decoder_inputs, mm_processor_kwargs)
return processed_inputs