[Bugfix] Fix multi-modal processors for transformers 4.48 (#12187)

Cyrus Leung
2025-01-19 11:16:34 +08:00
committed by GitHub
parent 4e94951bb1
commit 630eb5b5ce
6 changed files with 198 additions and 35 deletions

vllm/model_executor/models/llava.py

@@ -5,9 +5,11 @@ from typing import (Final, Iterable, List, Literal, Mapping, Optional,
 import torch
 import torch.nn as nn
+from packaging.version import Version
 from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
                           PixtralVisionConfig, PretrainedConfig,
                           SiglipVisionConfig)
+from transformers import __version__ as TRANSFORMERS_VERSION
 from transformers.models.llava import LlavaProcessor
 from transformers.models.pixtral import PixtralProcessor
@@ -716,6 +718,27 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         return loader.load_weights(weights)
 
 
+class MantisProcessingInfo(LlavaProcessingInfo):
+
+    def get_hf_processor(self):
+        hf_config = self.get_hf_config()
+        vision_info = self.get_vision_encoder_info()
+
+        if Version(TRANSFORMERS_VERSION) < Version("4.48"):
+            # BUG: num_additional_image_tokens = 0 but treated as 1,
+            # so we set vision_feature_select_strategy to None to offset this
+            vision_feature_select_strategy = None
+        else:
+            # FIXED: https://github.com/huggingface/transformers/pull/33424/files#diff-6a37acc21efcadaae622b079b2712a131131448ff64262bd219aa346aeec38faL150
+            vision_feature_select_strategy = hf_config.vision_feature_select_strategy  # noqa: E501
+
+        return self.ctx.get_hf_processor(
+            LlavaProcessor,
+            patch_size=vision_info.get_patch_size(),
+            vision_feature_select_strategy=vision_feature_select_strategy,
+        )
+
+
 class MantisMultiModalProcessor(LlavaMultiModalProcessor):
 
     def apply(
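A note on the version gate above: comparing parsed Version objects rather than raw version strings matters here, because string comparison is lexicographic and misorders transformers releases around 4.48. A minimal sketch of the difference (standalone, not part of the change itself):

    from packaging.version import Version

    # Lexicographic string comparison misorders these releases...
    assert not ("4.9" < "4.48")
    # ...while parsed versions compare numerically, as the gate relies on.
    assert Version("4.9") < Version("4.48")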
@@ -794,7 +817,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
 # To use this model, please use
 # `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
 @MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor,
-                                        info=LlavaProcessingInfo,
+                                        info=MantisProcessingInfo,
                                         dummy_inputs=LlavaDummyInputsBuilder)
 class MantisForConditionalGeneration(LlavaForConditionalGeneration):
     pass

vllm/model_executor/models/qwen2_audio.py

@@ -36,8 +36,9 @@ from vllm.config import VllmConfig
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
-from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
-                                    NestedTensors)
+from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
+                                    MultiModalInputsV2, MultiModalKwargs,
+                                    NestedTensors, PlaceholderRange)
 from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
                                    MultiModalDataParser)
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
@@ -153,29 +154,24 @@ class Qwen2AudioMultiModalProcessor(
         mm_data: Mapping[str, object],
         mm_kwargs: Mapping[str, Any],
     ) -> BatchFeature:
-        mm_data = dict(mm_data)
-        audios = mm_data.pop("audios", [])
-
-        if audios:
-            mm_data["audios"] = audios
-
-            feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
-            mm_kwargs = dict(
-                **mm_kwargs,
-                sampling_rate=feature_extractor.sampling_rate,
-            )
-        else:
-            # NOTE: WhisperFeatureExtractor cannot handle empty list of audios
-            pass
+        # Text-only input not supported in composite processor
+        if not mm_data or not mm_data.get("audios", []):
+            prompt_ids = self.info.get_tokenizer().encode(prompt)
+            prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
+            return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
 
-        processed_outputs = super()._call_hf_processor(
+        feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
+        mm_kwargs = dict(
+            **mm_kwargs,
+            sampling_rate=feature_extractor.sampling_rate,
+        )
+
+        return super()._call_hf_processor(
             prompt=prompt,
             mm_data=mm_data,
             mm_kwargs=mm_kwargs,
         )
 
-        return processed_outputs
-
     def _get_mm_fields_config(
         self,
         hf_inputs: BatchFeature,
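The early return above now also triggers when "audios" is present but empty, since WhisperFeatureExtractor cannot handle an empty list of audios (the removed NOTE). A quick sketch of what the guard treats as text-only, with plain dicts standing in for mm_data (the helper name is hypothetical):

    def is_text_only(mm_data: dict) -> bool:
        # Mirrors the condition in _call_hf_processor above.
        return not mm_data or not mm_data.get("audios", [])

    assert is_text_only({})                         # no multi-modal data at all
    assert is_text_only({"audios": []})             # empty audio list
    assert not is_text_only({"audios": ["<wav>"]})  # real audio goes to the HF processor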
@@ -192,8 +188,14 @@ class Qwen2AudioMultiModalProcessor(
         hf_processor_mm_kwargs: Mapping[str, object],
         out_mm_kwargs: MultiModalKwargs,
     ) -> list[PromptReplacement]:
-        hf_config = self.info.get_hf_config()
-        placeholder = hf_config.audio_token_index
+        processor = self.info.get_hf_processor()
+
+        # Use getattr with default to be compatible with transformers<4.48
+        audio_token = getattr(processor, "audio_token", "<|AUDIO|>")
+        audio_bos_token = getattr(processor, "audio_bos_token",
+                                  "<|audio_bos|>")
+        audio_eos_token = getattr(processor, "audio_eos_token",
+                                  "<|audio_eos|>")
 
         feature_attention_mask = out_mm_kwargs.get("feature_attention_mask")
         if feature_attention_mask is None:
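The getattr fallbacks keep this code working against transformers releases where the Qwen2-Audio processor does not yet expose these attributes; the literal defaults match the tokens the model uses. An illustrative stand-in (both classes here are hypothetical):

    class OldProcessor:          # pre-4.48: no audio_token attribute
        pass

    class NewProcessor:          # 4.48+: attribute provided by transformers
        audio_token = "<|AUDIO|>"

    # The fallback is only consulted when the attribute is missing.
    assert getattr(OldProcessor(), "audio_token", "<fallback>") == "<fallback>"
    assert getattr(NewProcessor(), "audio_token", "<fallback>") == "<|AUDIO|>"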
@@ -214,12 +216,16 @@ class Qwen2AudioMultiModalProcessor(
                     f"The audio {audio} (len={len(audio)}) is too short "
                     "to be represented inside the model")
 
-            return [placeholder] * num_placeholders
+            return "".join([
+                audio_bos_token,
+                audio_token * num_placeholders,
+                audio_eos_token,
+            ])
 
         return [
             PromptReplacement(
                 modality="audio",
-                target=[placeholder],
+                target=audio_token,
                 replacement=get_replacement_qwen2_audio,
             )
         ]
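With the switch from token-id lists to strings, an audio item that occupies, say, three feature positions now expands into a bos/eos-delimited span. A worked example of the replacement string (num_placeholders = 3 is assumed purely for illustration):

    replacement = "".join([
        "<|audio_bos|>",
        "<|AUDIO|>" * 3,
        "<|audio_eos|>",
    ])
    assert replacement == "<|audio_bos|><|AUDIO|><|AUDIO|><|AUDIO|><|audio_eos|>"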
@@ -234,6 +240,26 @@ class Qwen2AudioMultiModalProcessor(
         # tokens than the number of audio items)
         return not hasattr(self.info.get_hf_processor(), "audio_token")
 
+    def apply(
+        self,
+        prompt: Union[str, list[int]],
+        mm_data: MultiModalDataDict,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> MultiModalInputsV2:
+        result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
+
+        # Only <|AUDIO|> tokens should be considered as placeholders,
+        # so we ignore the audio_bos_token and audio_eos_token
+        result["mm_placeholders"] = {
+            modality: [
+                PlaceholderRange(offset=p["offset"] + 1,
+                                 length=p["length"] - 2) for p in ps
+            ]
+            for modality, ps in result["mm_placeholders"].items()
+        }
+
+        return result
+
 
 @MULTIMODAL_REGISTRY.register_processor(
     Qwen2AudioMultiModalProcessor,
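The offset/length arithmetic in apply strips exactly the delimiter pair that the replacement above introduced. For the three-placeholder example, a span of length 5 starting at token offset 10 shrinks to just the three <|AUDIO|> tokens (plain dicts stand in for PlaceholderRange here):

    span = {"offset": 10, "length": 5}        # <|audio_bos|> + 3 * <|AUDIO|> + <|audio_eos|>
    trimmed = {"offset": span["offset"] + 1,  # skip <|audio_bos|>
               "length": span["length"] - 2}  # drop both delimiter tokens
    assert trimmed == {"offset": 11, "length": 3}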

vllm/model_executor/models/ultravox.py

@@ -137,7 +137,7 @@ class UltravoxMultiModalProcessor(
         mm_kwargs: Mapping[str, object],
     ) -> BatchFeature:
         # Text-only input not supported in composite processor
-        if not mm_data:
+        if not mm_data or not mm_data.get("audios", []):
             prompt_ids = self.info.get_tokenizer().encode(prompt)
             prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
             return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
@@ -146,13 +146,6 @@ class UltravoxMultiModalProcessor(
         audios = mm_data.pop("audios", [])
         assert isinstance(audios, list)
 
-        if not audios:
-            return super()._call_hf_processor(
-                prompt=prompt,
-                mm_data=mm_data,
-                mm_kwargs=mm_kwargs,
-            )
-
         feature_extractor = self.info.get_feature_extractor()
         mm_kwargs = dict(
             **mm_kwargs,