[VLM] Implement merged multimodal processor for Mllama (#11427)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import asyncio
|
||||
from typing import List, Mapping, Optional, Union
|
||||
from typing import List, Mapping, Optional, Tuple, Union, cast
|
||||
|
||||
from typing_extensions import assert_never
|
||||
|
||||
@@ -9,7 +9,8 @@ from vllm.config import ModelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
|
||||
from vllm.multimodal.inputs import MultiModalDataDict, MultiModalInputs
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
|
||||
MultiModalInputs)
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
|
||||
|
||||
@@ -495,6 +496,51 @@ class InputPreprocessor:
|
||||
decoder=decoder_inputs,
|
||||
)
|
||||
|
||||
def _separate_enc_dec_inputs_from_mm_processor_outputs(
|
||||
self,
|
||||
inputs: SingletonInputs,
|
||||
decoder_inputs_to_override: Optional[SingletonInputs] = None,
|
||||
) -> Tuple[SingletonInputs, SingletonInputs]:
|
||||
"""
|
||||
For encoder/decoder models only:
|
||||
Separate Encoder/Decoder inputs from a MultiModalEncDecInputs
|
||||
"""
|
||||
encoder_inputs: SingletonInputs
|
||||
decoder_inputs: SingletonInputs
|
||||
if inputs["type"] == "multimodal":
|
||||
# Multimodal data inputs
|
||||
assert ("encoder_prompt" in inputs
|
||||
and "encoder_prompt_token_ids" in inputs)
|
||||
inputs = cast(MultiModalEncDecInputs, inputs)
|
||||
encoder_inputs = token_inputs(
|
||||
prompt=inputs["encoder_prompt"],
|
||||
prompt_token_ids=inputs["encoder_prompt_token_ids"],
|
||||
)
|
||||
if decoder_inputs_to_override is not None:
|
||||
decoder_inputs = MultiModalInputs(
|
||||
type="multimodal",
|
||||
prompt=decoder_inputs_to_override.get("prompt", ""),
|
||||
prompt_token_ids=decoder_inputs_to_override[
|
||||
"prompt_token_ids"],
|
||||
mm_kwargs=inputs["mm_kwargs"],
|
||||
mm_placeholders=inputs["mm_placeholders"],
|
||||
)
|
||||
else:
|
||||
decoder_inputs = MultiModalInputs(
|
||||
type="multimodal",
|
||||
prompt=inputs["prompt"],
|
||||
prompt_token_ids=inputs["prompt_token_ids"],
|
||||
mm_kwargs=inputs["mm_kwargs"],
|
||||
mm_placeholders=inputs["mm_placeholders"],
|
||||
)
|
||||
elif inputs["type"] == "token":
|
||||
# Text-only inputs
|
||||
encoder_inputs = token_inputs(prompt="", prompt_token_ids=[])
|
||||
decoder_inputs = decoder_inputs_to_override or inputs
|
||||
else:
|
||||
assert_never(inputs) # type: ignore[arg-type]
|
||||
return encoder_inputs, decoder_inputs
|
||||
|
||||
def _process_encoder_decoder_prompt(
|
||||
self,
|
||||
prompt: PromptType,
|
||||
@@ -539,7 +585,6 @@ class InputPreprocessor:
|
||||
prompt["encoder_prompt"],
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
if (decoder_input := prompt["decoder_prompt"]) is None:
|
||||
decoder_inputs = None
|
||||
else:
|
||||
@@ -547,13 +592,28 @@ class InputPreprocessor:
|
||||
decoder_input,
|
||||
request_id=request_id,
|
||||
)
|
||||
# For multimodal model, override decoder prompt from processor
|
||||
# with explicit decoder prompt.
|
||||
if self.model_config.is_multimodal_model and (
|
||||
self._can_process_multimodal()):
|
||||
encoder_inputs, decoder_inputs = (
|
||||
self._separate_enc_dec_inputs_from_mm_processor_outputs(
|
||||
encoder_inputs, decoder_inputs))
|
||||
else:
|
||||
encoder_inputs = self._prompt_to_llm_inputs(
|
||||
inputs = self._prompt_to_llm_inputs(
|
||||
prompt,
|
||||
request_id=request_id,
|
||||
)
|
||||
if self.model_config.is_multimodal_model and (
|
||||
self._can_process_multimodal()):
|
||||
# Encoder-Decoder Multimodal model
|
||||
encoder_inputs, decoder_inputs = (
|
||||
self._separate_enc_dec_inputs_from_mm_processor_outputs(
|
||||
inputs))
|
||||
else:
|
||||
encoder_inputs = inputs
|
||||
|
||||
decoder_inputs = None
|
||||
decoder_inputs = None
|
||||
|
||||
return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
|
||||
|
||||
@@ -583,13 +643,29 @@ class InputPreprocessor:
|
||||
|
||||
encoder_inputs, decoder_inputs = await asyncio.gather(
|
||||
encoder_task, decoder_task)
|
||||
|
||||
# For multimodal model, override decoder prompt from processor
|
||||
# with explicit decoder prompt.
|
||||
if self.model_config.is_multimodal_model and (
|
||||
self._can_process_multimodal()):
|
||||
encoder_inputs, decoder_inputs = (
|
||||
self._separate_enc_dec_inputs_from_mm_processor_outputs(
|
||||
encoder_inputs, decoder_inputs))
|
||||
else:
|
||||
encoder_inputs = await self._prompt_to_llm_inputs_async(
|
||||
inputs = await self._prompt_to_llm_inputs_async(
|
||||
prompt,
|
||||
request_id=request_id,
|
||||
)
|
||||
if self.model_config.is_multimodal_model and (
|
||||
self._can_process_multimodal()):
|
||||
# Encoder-Decoder Multimodal model
|
||||
encoder_inputs, decoder_inputs = (
|
||||
self._separate_enc_dec_inputs_from_mm_processor_outputs(
|
||||
inputs))
|
||||
else:
|
||||
encoder_inputs = inputs
|
||||
|
||||
decoder_inputs = None
|
||||
decoder_inputs = None
|
||||
|
||||
return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user