[Bugfix] Fix multi-modal processors for transformers 4.48 (#12187)
This commit is contained in:
@@ -5,9 +5,11 @@ from typing import (Final, Iterable, List, Literal, Mapping, Optional,
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from packaging.version import Version
|
||||
from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
|
||||
PixtralVisionConfig, PretrainedConfig,
|
||||
SiglipVisionConfig)
|
||||
from transformers import __version__ as TRANSFORMERS_VERSION
|
||||
from transformers.models.llava import LlavaProcessor
|
||||
from transformers.models.pixtral import PixtralProcessor
|
||||
|
||||
@@ -716,6 +718,27 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
return loader.load_weights(weights)
|
||||
|
||||
|
||||
class MantisProcessingInfo(LlavaProcessingInfo):
|
||||
|
||||
def get_hf_processor(self):
|
||||
hf_config = self.get_hf_config()
|
||||
vision_info = self.get_vision_encoder_info()
|
||||
|
||||
if Version(TRANSFORMERS_VERSION) < Version("4.48"):
|
||||
# BUG: num_additional_image_tokens = 0 but treated as 1,
|
||||
# so we set vision_feature_select_strategy to None to offset this
|
||||
vision_feature_select_strategy = None
|
||||
else:
|
||||
# FIXED: https://github.com/huggingface/transformers/pull/33424/files#diff-6a37acc21efcadaae622b079b2712a131131448ff64262bd219aa346aeec38faL150
|
||||
vision_feature_select_strategy = hf_config.vision_feature_select_strategy # noqa: E501
|
||||
|
||||
return self.ctx.get_hf_processor(
|
||||
LlavaProcessor,
|
||||
patch_size=vision_info.get_patch_size(),
|
||||
vision_feature_select_strategy=vision_feature_select_strategy,
|
||||
)
|
||||
|
||||
|
||||
class MantisMultiModalProcessor(LlavaMultiModalProcessor):
|
||||
|
||||
def apply(
|
||||
@@ -794,7 +817,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
|
||||
# To use this model, please use
|
||||
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
|
||||
@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor,
|
||||
info=LlavaProcessingInfo,
|
||||
info=MantisProcessingInfo,
|
||||
dummy_inputs=LlavaDummyInputsBuilder)
|
||||
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
|
||||
pass
|
||||
|
||||
@@ -36,8 +36,9 @@ from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.inputs import (MultiModalFieldConfig, MultiModalKwargs,
|
||||
NestedTensors)
|
||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||
MultiModalInputsV2, MultiModalKwargs,
|
||||
NestedTensors, PlaceholderRange)
|
||||
from vllm.multimodal.parse import (AudioProcessorItems, MultiModalDataItems,
|
||||
MultiModalDataParser)
|
||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||
@@ -153,29 +154,24 @@ class Qwen2AudioMultiModalProcessor(
|
||||
mm_data: Mapping[str, object],
|
||||
mm_kwargs: Mapping[str, Any],
|
||||
) -> BatchFeature:
|
||||
mm_data = dict(mm_data)
|
||||
audios = mm_data.pop("audios", [])
|
||||
# Text-only input not supported in composite processor
|
||||
if not mm_data or not mm_data.get("audios", []):
|
||||
prompt_ids = self.info.get_tokenizer().encode(prompt)
|
||||
prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
|
||||
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
|
||||
|
||||
if audios:
|
||||
mm_data["audios"] = audios
|
||||
feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
|
||||
mm_kwargs = dict(
|
||||
**mm_kwargs,
|
||||
sampling_rate=feature_extractor.sampling_rate,
|
||||
)
|
||||
|
||||
feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
|
||||
mm_kwargs = dict(
|
||||
**mm_kwargs,
|
||||
sampling_rate=feature_extractor.sampling_rate,
|
||||
)
|
||||
else:
|
||||
# NOTE: WhisperFeatureExtractor cannot handle empty list of audios
|
||||
pass
|
||||
|
||||
processed_outputs = super()._call_hf_processor(
|
||||
return super()._call_hf_processor(
|
||||
prompt=prompt,
|
||||
mm_data=mm_data,
|
||||
mm_kwargs=mm_kwargs,
|
||||
)
|
||||
|
||||
return processed_outputs
|
||||
|
||||
def _get_mm_fields_config(
|
||||
self,
|
||||
hf_inputs: BatchFeature,
|
||||
@@ -192,8 +188,14 @@ class Qwen2AudioMultiModalProcessor(
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> list[PromptReplacement]:
|
||||
hf_config = self.info.get_hf_config()
|
||||
placeholder = hf_config.audio_token_index
|
||||
processor = self.info.get_hf_processor()
|
||||
|
||||
# Use getattr with default to be compatible with transformers<4.48
|
||||
audio_token = getattr(processor, "audio_token", "<|AUDIO|>")
|
||||
audio_bos_token = getattr(processor, "audio_bos_token",
|
||||
"<|audio_bos|>")
|
||||
audio_eos_token = getattr(processor, "audio_eos_token",
|
||||
"<|audio_eos|>")
|
||||
|
||||
feature_attention_mask = out_mm_kwargs.get("feature_attention_mask")
|
||||
if feature_attention_mask is None:
|
||||
@@ -214,12 +216,16 @@ class Qwen2AudioMultiModalProcessor(
|
||||
f"The audio {audio} (len={len(audio)}) is too short "
|
||||
"to be represented inside the model")
|
||||
|
||||
return [placeholder] * num_placeholders
|
||||
return "".join([
|
||||
audio_bos_token,
|
||||
audio_token * num_placeholders,
|
||||
audio_eos_token,
|
||||
])
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="audio",
|
||||
target=[placeholder],
|
||||
target=audio_token,
|
||||
replacement=get_replacement_qwen2_audio,
|
||||
)
|
||||
]
|
||||
@@ -234,6 +240,26 @@ class Qwen2AudioMultiModalProcessor(
|
||||
# tokens than the number of audio items)
|
||||
return not hasattr(self.info.get_hf_processor(), "audio_token")
|
||||
|
||||
def apply(
|
||||
self,
|
||||
prompt: Union[str, list[int]],
|
||||
mm_data: MultiModalDataDict,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
) -> MultiModalInputsV2:
|
||||
result = super().apply(prompt, mm_data, hf_processor_mm_kwargs)
|
||||
|
||||
# Only <|AUDIO|> tokens should be considered as placeholders,
|
||||
# so we ignore the audio_bos_token and audio_eos_token
|
||||
result["mm_placeholders"] = {
|
||||
modality: [
|
||||
PlaceholderRange(offset=p["offset"] + 1,
|
||||
length=p["length"] - 2) for p in ps
|
||||
]
|
||||
for modality, ps in result["mm_placeholders"].items()
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_processor(
|
||||
Qwen2AudioMultiModalProcessor,
|
||||
|
||||
@@ -137,7 +137,7 @@ class UltravoxMultiModalProcessor(
|
||||
mm_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
# Text-only input not supported in composite processor
|
||||
if not mm_data:
|
||||
if not mm_data or not mm_data.get("audios", []):
|
||||
prompt_ids = self.info.get_tokenizer().encode(prompt)
|
||||
prompt_ids = self._apply_hf_processor_tokens_only(prompt_ids)
|
||||
return BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")
|
||||
@@ -146,13 +146,6 @@ class UltravoxMultiModalProcessor(
|
||||
audios = mm_data.pop("audios", [])
|
||||
assert isinstance(audios, list)
|
||||
|
||||
if not audios:
|
||||
return super()._call_hf_processor(
|
||||
prompt=prompt,
|
||||
mm_data=mm_data,
|
||||
mm_kwargs=mm_kwargs,
|
||||
)
|
||||
|
||||
feature_extractor = self.info.get_feature_extractor()
|
||||
mm_kwargs = dict(
|
||||
**mm_kwargs,
|
||||
|
||||
Reference in New Issue
Block a user