[Bugfix] Avoid transferring cached multi-modal items from P0 to P1 (#16273)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import time
-from collections.abc import Mapping
+from collections.abc import Mapping, Sequence
 from typing import Literal, Optional, Union
 
 from vllm.config import VllmConfig
@@ -19,6 +19,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
 from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
 from vllm.v1.engine import EngineCoreRequest
+from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
 from vllm.v1.structured_output.backend_guidance import (
     validate_guidance_grammar)
 from vllm.v1.structured_output.utils import (
@@ -47,6 +48,8 @@ class Processor:
             self.tokenizer,
             mm_registry)
 
+        self.mm_input_cache_client = MirroredProcessingCache(self.model_config)
+
         # Multi-modal hasher (for images)
         self.use_hash = (
             not self.model_config.disable_mm_preprocessor_cache) or \
@@ -231,7 +234,7 @@ class Processor:
             self.tokenizer.get_lora_tokenizer(lora_request))
 
         # Multimodal related.
-        sorted_mm_inputs: Optional[list[MultiModalKwargs]] = None
+        sorted_mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] = None
         sorted_mm_positions: Optional[list[PlaceholderRange]] = None
         sorted_mm_hashes: Optional[list[str]] = None
         if decoder_inputs["type"] == "multimodal":
@@ -256,20 +259,28 @@ class Processor:
             # are multiple modalities.
             unique_modalities = set(sorted_item_modalities)
             if len(unique_modalities) > 1:
-                sorted_mm_inputs = []
+                orig_sorted_mm_inputs = []
                 used_indices = {modality: 0 for modality in unique_modalities}
 
                 for modality in sorted_item_modalities:
                     items = decoder_mm_inputs.get_items(modality)
                     item = items[used_indices[modality]]
-                    sorted_mm_inputs.append(MultiModalKwargs.from_items([item
-                                                                         ]))
+
+                    orig_sorted_mm_inputs.append(
+                        MultiModalKwargs.from_items([item]))
                     used_indices[modality] += 1
             else:
-                sorted_mm_inputs = [
+                orig_sorted_mm_inputs = [
                     MultiModalKwargs.from_items([item]) for item in
                     decoder_mm_inputs.get_items(sorted_item_modalities[0])
                 ]
+
+            if sorted_mm_hashes is not None:
+                sorted_mm_inputs = self.mm_input_cache_client.get_and_update_p0(
+                    orig_sorted_mm_inputs, sorted_mm_hashes)
+            else:
+                sorted_mm_inputs = orig_sorted_mm_inputs
 
         return EngineCoreRequest(
            request_id=request_id,
            prompt=decoder_inputs.get("prompt"),
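Only the P0 (frontend) side of the cache appears in this diff; a matching lookup on P1 (the engine-core process) is what resolves the placeholders again. Below is a minimal, hypothetical sketch of the mirrored-cache idea, assuming an LRU eviction policy. MirroredCacheSketch, its capacity, and get_and_update_p1 are illustrative names, not the real vllm.v1.engine.mm_input_cache API; only get_and_update_p0 is confirmed by the hunk above. The sketch also shows why sorted_mm_inputs is re-annotated as Optional[Sequence[Optional[MultiModalKwargs]]]: after the cache pass, entries that hit the cache become None and are never serialized across the process boundary.

# Hypothetical sketch -- NOT the real vllm.v1.engine.mm_input_cache code.
from collections import OrderedDict
from typing import Optional


class MirroredCacheSketch:
    """P0 and P1 each hold one of these; identical capacity and eviction
    policy mean identical updates keep the two caches in lockstep."""

    def __init__(self, capacity: int = 4) -> None:
        self._cache: OrderedDict[str, object] = OrderedDict()
        self._capacity = capacity

    def _put(self, key: str, value: object) -> None:
        # LRU insert/update; evict the oldest entry when over capacity.
        self._cache[key] = value
        self._cache.move_to_end(key)
        if len(self._cache) > self._capacity:
            self._cache.popitem(last=False)

    def get_and_update_p0(
            self, mm_inputs: list[object],
            mm_hashes: list[str]) -> list[Optional[object]]:
        # P0 side: replace items P1 already holds with None so the
        # (potentially large) processed tensors are not re-serialized
        # and re-sent with every request.
        out: list[Optional[object]] = []
        for item, key in zip(mm_inputs, mm_hashes):
            out.append(None if key in self._cache else item)
            self._put(key, item)
        return out

    def get_and_update_p1(
            self, mm_inputs: list[Optional[object]],
            mm_hashes: list[str]) -> list[object]:
        # P1 side: restore None placeholders from the local mirror and
        # apply the same cache update P0 made, keeping both in sync.
        out: list[object] = []
        for item, key in zip(mm_inputs, mm_hashes):
            if item is None:
                item = self._cache[key]
            self._put(key, item)
            out.append(item)
        return out


p0, p1 = MirroredCacheSketch(), MirroredCacheSketch()
first = p0.get_and_update_p0(["image-tensors"], ["h1"])   # ["image-tensors"]
p1.get_and_update_p1(first, ["h1"])                       # mirror populated
again = p0.get_and_update_p0(["image-tensors"], ["h1"])   # [None] -- cache hit
assert p1.get_and_update_p1(again, ["h1"]) == ["image-tensors"]

Because both processes apply the same insert-and-evict step for every item, the two LRU states stay identical without any extra coordination traffic: a repeated multi-modal item costs one hash string on the wire instead of its processed tensors.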