[VLM] Generalized prompt updates for multi-modal processor (#13964)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-02-28 01:44:25 +08:00
committed by GitHub
parent 7864875879
commit f1579b229d
29 changed files with 629 additions and 486 deletions

View File

@@ -1,9 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
import math
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import (Final, Iterable, List, Literal, Mapping, Optional,
Protocol, Set, Tuple, TypedDict, Union)
from typing import (Final, List, Literal, Optional, Protocol, Set, Tuple,
TypedDict, Union)
import torch
import torch.nn as nn
@@ -22,7 +23,7 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.parse import (ImageSize, MultiModalDataItems,
VideoEmbeddingItems, VideoProcessorItems)
from vllm.multimodal.processing import PromptReplacement
from vllm.multimodal.processing import PromptReplacement, PromptUpdate
from vllm.multimodal.profiling import ProcessorInputs
from vllm.sequence import IntermediateTensors
from vllm.utils import is_list_of
@@ -347,13 +348,13 @@ class LlavaOnevisionMultiModalProcessor(
)
return BatchFeature(combined_outputs)
def _hf_processor_applies_repl(
def _hf_processor_applies_updates(
self,
prompt_text: str,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
) -> bool:
base_result = super()._hf_processor_applies_repl(
base_result = super()._hf_processor_applies_updates(
prompt_text=prompt_text,
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
@@ -361,13 +362,13 @@ class LlavaOnevisionMultiModalProcessor(
return base_result and mm_items.get_count("video", strict=False) == 0
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
image_repls = super()._get_prompt_replacements(
) -> Sequence[PromptUpdate]:
image_repls = super()._get_prompt_updates(
mm_items=mm_items,
hf_processor_mm_kwargs=hf_processor_mm_kwargs,
out_mm_kwargs=out_mm_kwargs,
@@ -392,7 +393,8 @@ class LlavaOnevisionMultiModalProcessor(
return [video_token_id] * num_video_tokens
return image_repls + [
return [
*image_repls,
PromptReplacement(
modality="video",
target=[video_token_id],