[Bugfix] Avoid transferring cached multi-modal items from P0 to P1 (#16273)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-04-09 15:51:27 +08:00
committed by GitHub
parent 24f6b9a713
commit e484e02857
5 changed files with 65 additions and 22 deletions

View File

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
import time
-from collections.abc import Mapping
+from collections.abc import Mapping, Sequence
from typing import Literal, Optional, Union
from vllm.config import VllmConfig
@@ -19,6 +19,7 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
from vllm.v1.engine import EngineCoreRequest
+from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
from vllm.v1.structured_output.backend_guidance import (
validate_guidance_grammar)
from vllm.v1.structured_output.utils import (
@@ -47,6 +48,8 @@ class Processor:
self.tokenizer,
mm_registry)
+        self.mm_input_cache_client = MirroredProcessingCache(self.model_config)
# Multi-modal hasher (for images)
self.use_hash = (
not self.model_config.disable_mm_preprocessor_cache) or \
@@ -231,7 +234,7 @@ class Processor:
self.tokenizer.get_lora_tokenizer(lora_request))
# Multimodal related.
-        sorted_mm_inputs: Optional[list[MultiModalKwargs]] = None
+        sorted_mm_inputs: Optional[Sequence[Optional[MultiModalKwargs]]] = None
sorted_mm_positions: Optional[list[PlaceholderRange]] = None
sorted_mm_hashes: Optional[list[str]] = None
if decoder_inputs["type"] == "multimodal":
@@ -256,20 +259,28 @@ class Processor:
# are multiple modalities.
unique_modalities = set(sorted_item_modalities)
if len(unique_modalities) > 1:
-                sorted_mm_inputs = []
+                orig_sorted_mm_inputs = []
used_indices = {modality: 0 for modality in unique_modalities}
for modality in sorted_item_modalities:
items = decoder_mm_inputs.get_items(modality)
item = items[used_indices[modality]]
-                    sorted_mm_inputs.append(MultiModalKwargs.from_items([item
-                                                                         ]))
+                    orig_sorted_mm_inputs.append(
+                        MultiModalKwargs.from_items([item]))
used_indices[modality] += 1
else:
-                sorted_mm_inputs = [
+                orig_sorted_mm_inputs = [
MultiModalKwargs.from_items([item]) for item in
decoder_mm_inputs.get_items(sorted_item_modalities[0])
]
+            if sorted_mm_hashes is not None:
+                sorted_mm_inputs = self.mm_input_cache_client.get_and_update_p0(
+                    orig_sorted_mm_inputs, sorted_mm_hashes)
+            else:
+                sorted_mm_inputs = orig_sorted_mm_inputs
return EngineCoreRequest(
request_id=request_id,
prompt=decoder_inputs.get("prompt"),