[Bugfix] Fix multi-modal processors for transformers 4.48 (#12187)
This commit is contained in:
@@ -5,9 +5,11 @@ from typing import (Final, Iterable, List, Literal, Mapping, Optional,
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from packaging.version import Version
|
||||
from transformers import (BatchFeature, CLIPVisionConfig, LlavaConfig,
|
||||
PixtralVisionConfig, PretrainedConfig,
|
||||
SiglipVisionConfig)
|
||||
from transformers import __version__ as TRANSFORMERS_VERSION
|
||||
from transformers.models.llava import LlavaProcessor
|
||||
from transformers.models.pixtral import PixtralProcessor
|
||||
|
||||
@@ -716,6 +718,27 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
return loader.load_weights(weights)
|
||||
|
||||
|
||||
class MantisProcessingInfo(LlavaProcessingInfo):
    """Processing info for Mantis that builds a version-aware LlavaProcessor."""

    def get_hf_processor(self):
        """Return the ``LlavaProcessor`` obtained through ``self.ctx``.

        The processor is requested with the vision encoder's patch size and
        a ``vision_feature_select_strategy`` chosen according to the
        installed ``transformers`` version.
        """
        config = self.get_hf_config()
        encoder_info = self.get_vision_encoder_info()

        # transformers < 4.48 has a BUG: num_additional_image_tokens = 0 is
        # treated as 1, so vision_feature_select_strategy is set to None to
        # offset this. From 4.48 on this is FIXED:
        # https://github.com/huggingface/transformers/pull/33424/files#diff-6a37acc21efcadaae622b079b2712a131131448ff64262bd219aa346aeec38faL150
        needs_workaround = Version(TRANSFORMERS_VERSION) < Version("4.48")
        strategy = (None if needs_workaround else
                    config.vision_feature_select_strategy)

        return self.ctx.get_hf_processor(
            LlavaProcessor,
            patch_size=encoder_info.get_patch_size(),
            vision_feature_select_strategy=strategy,
        )
|
||||
|
||||
|
||||
class MantisMultiModalProcessor(LlavaMultiModalProcessor):
|
||||
|
||||
def apply(
|
||||
@@ -794,7 +817,7 @@ class MantisMultiModalProcessor(LlavaMultiModalProcessor):
|
||||
# To use this model, please use
|
||||
# `--hf_overrides '{"architectures": ["MantisForConditionalGeneration"]}'`
|
||||
# The processor is registered with MantisProcessingInfo (not the base
# LlavaProcessingInfo) so that the Mantis-specific `get_hf_processor`
# is used. Only one `info=` keyword may be passed: a repeated keyword
# argument is a SyntaxError.
@MULTIMODAL_REGISTRY.register_processor(MantisMultiModalProcessor,
                                        info=MantisProcessingInfo,
                                        dummy_inputs=LlavaDummyInputsBuilder)
class MantisForConditionalGeneration(LlavaForConditionalGeneration):
    """Llava model variant registered with the Mantis multi-modal processor.

    Adds no behavior of its own; it differs from the parent class only in
    the processor, processing info and dummy-input builder it is registered
    with.
    """
|
||||
|
||||
Reference in New Issue
Block a user