[VLM] Generalized prompt updates for multi-modal processor (#13964)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-02-28 01:44:25 +08:00
committed by GitHub
parent 7864875879
commit f1579b229d
29 changed files with 629 additions and 486 deletions

View File

@@ -1,8 +1,8 @@
# SPDX-License-Identifier: Apache-2.0
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import (Any, Dict, Iterable, Literal, Mapping, Optional, Set,
Tuple, TypedDict, Union)
from typing import Any, Dict, Literal, Optional, Set, Tuple, TypedDict, Union
import torch
import torch.nn as nn
@@ -35,7 +35,7 @@ from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
from vllm.multimodal.parse import MultiModalDataItems
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, PromptReplacement,
PromptReplacementDetails)
PromptUpdate, PromptUpdateDetails)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from vllm.sequence import IntermediateTensors
@@ -141,12 +141,12 @@ class ChameleonMultiModalProcessor(
) -> Mapping[str, MultiModalFieldConfig]:
return dict(pixel_values=MultiModalFieldConfig.batched("image"))
def _get_prompt_replacements(
def _get_prompt_updates(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
) -> Sequence[PromptUpdate]:
processor = self.info.get_hf_processor(**hf_processor_mm_kwargs)
tokenizer = self.info.get_tokenizer()
vocab = tokenizer.get_vocab()
@@ -162,7 +162,7 @@ class ChameleonMultiModalProcessor(
PromptReplacement(
modality="image",
target=[image_token_id],
replacement=PromptReplacementDetails(
replacement=PromptUpdateDetails(
full=([image_start_id] + image_tokens + [image_end_id]),
features=image_tokens,
),
@@ -371,7 +371,7 @@ class ChameleonDecoderLayer(nn.Module):
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
if residual is None:
residual = hidden_states