[Misc] Replace is_encoder_decoder_inputs with split_enc_dec_inputs (#15620)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -1,15 +1,13 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Literal, TypedDict, Union, cast, overload
|
||||
from typing import Literal, Optional, TypedDict, Union, cast, overload
|
||||
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
from vllm.utils import is_list_of
|
||||
|
||||
from .data import (EncoderDecoderInputs, ExplicitEncoderDecoderPrompt,
|
||||
ProcessorInputs, PromptType, SingletonPrompt, TextPrompt,
|
||||
TokensPrompt)
|
||||
from .data import (ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType,
|
||||
SingletonInputs, SingletonPrompt, TextPrompt, TokensPrompt)
|
||||
|
||||
|
||||
class ParsedText(TypedDict):
|
||||
@@ -110,6 +108,14 @@ def is_explicit_encoder_decoder_prompt(
|
||||
return isinstance(prompt, dict) and "encoder_prompt" in prompt
|
||||
|
||||
|
||||
def is_encoder_decoder_inputs(
|
||||
inputs: ProcessorInputs) -> TypeIs[EncoderDecoderInputs]:
|
||||
return "encoder" in inputs and "decoder" in inputs
|
||||
def split_enc_dec_inputs(
|
||||
inputs: ProcessorInputs,
|
||||
) -> tuple[Optional[SingletonInputs], SingletonInputs]:
|
||||
if "encoder" in inputs and "decoder" in inputs:
|
||||
# NOTE: This passes pyright but not mypy
|
||||
return (
|
||||
inputs["encoder"], # type: ignore[typeddict-item]
|
||||
inputs["decoder"], # type: ignore[typeddict-item]
|
||||
)
|
||||
|
||||
return None, inputs
|
||||
|
||||
@@ -19,7 +19,7 @@ from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides,
|
||||
resolve_mm_processor_kwargs)
|
||||
|
||||
from .data import ProcessorInputs, SingletonInputs
|
||||
from .parse import is_encoder_decoder_inputs
|
||||
from .parse import split_enc_dec_inputs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import ModelConfig
|
||||
@@ -462,13 +462,11 @@ class InputRegistry:
|
||||
**mm_processor_kwargs,
|
||||
)
|
||||
|
||||
if is_encoder_decoder_inputs(processed_inputs):
|
||||
self._ensure_mm_kwargs(processed_inputs["encoder"],
|
||||
mm_processor_kwargs)
|
||||
self._ensure_mm_kwargs(processed_inputs["decoder"],
|
||||
mm_processor_kwargs)
|
||||
else:
|
||||
self._ensure_mm_kwargs(processed_inputs, mm_processor_kwargs)
|
||||
encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
|
||||
if encoder_inputs is not None:
|
||||
self._ensure_mm_kwargs(encoder_inputs, mm_processor_kwargs)
|
||||
if decoder_inputs is not None:
|
||||
self._ensure_mm_kwargs(decoder_inputs, mm_processor_kwargs)
|
||||
|
||||
return processed_inputs
|
||||
|
||||
|
||||
Reference in New Issue
Block a user