[Bugfix] Fix multi-modal processors for transformers 4.48 (#12187)

This commit is contained in:
Cyrus Leung
2025-01-19 11:16:34 +08:00
committed by GitHub
parent 4e94951bb1
commit 630eb5b5ce
6 changed files with 198 additions and 35 deletions

View File

@@ -1,3 +1,4 @@
from vllm.transformers_utils.configs.aria import AriaConfig
from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
from vllm.transformers_utils.configs.cohere2 import Cohere2Config
from vllm.transformers_utils.configs.dbrx import DbrxConfig
@@ -23,6 +24,7 @@ from vllm.transformers_utils.configs.telechat2 import Telechat2Config
from vllm.transformers_utils.configs.ultravox import UltravoxConfig
__all__ = [
"AriaConfig",
"ChatGLMConfig",
"Cohere2Config",
"DbrxConfig",

View File

@@ -1,7 +1,32 @@
# Copyright 2024 Rhymes AI. All rights reserved.
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Mapping
from transformers import PretrainedConfig
from transformers.models.idefics2.configuration_idefics2 import (
Idefics2VisionConfig)
from transformers.models.llama.configuration_llama import LlamaConfig
from vllm.logger import init_logger
logger = init_logger(__name__)
class AriaVisionConfig(Idefics2VisionConfig):
model_type = "aria_vision_model"
@@ -45,3 +70,96 @@ class AriaMoELMConfig(LlamaConfig):
self.moe_num_experts = moe_num_experts
self.moe_topk = moe_topk
self.moe_num_shared_experts = moe_num_shared_experts
class AriaConfig(PretrainedConfig):
"""
Configuration class for Aria model.
This class handles the configuration for both vision and text components of
the Aria model,
as well as additional parameters for image token handling and projector
mapping.
Args:
vision_config (AriaVisionConfig or dict): Configuration for the vision
component.
text_config (AriaMoELMConfig or dict): Configuration for the text
component.
projector_patch_to_query_dict (dict): Mapping of patch sizes to query
dimensions.
ignore_index (int): Index to ignore in loss calculation.
image_token_index (int): Index used to represent image tokens.
**kwargs: Additional keyword arguments passed to the parent class.
Attributes:
model_type (str): Type of the model, set to "aria".
is_composition (bool): Whether the model is a composition of multiple
components.
ignore_index (int): Index to ignore in loss calculation.
image_token_index (int): Index used to represent image tokens.
projector_patch_to_query_dict (dict): Mapping of patch sizes to query
dimensions.
vision_config (AriaVisionConfig): Configuration for the vision
component.
text_config (AriaMoELMConfig): Configuration for the text component.
"""
model_type = "aria"
is_composition = False
def __init__(
self,
vision_config: AriaVisionConfig = AriaVisionConfig(), # noqa: B008
text_config: AriaMoELMConfig = AriaMoELMConfig(), # noqa: B008
projector_patch_to_query_dict: Mapping[int, int] = {
1225: 128,
4900: 256,
},
ignore_index=-100,
image_token_index=32000,
tie_word_embeddings=False,
**kwargs,
):
super().__init__(**kwargs)
self.ignore_index = ignore_index
self.image_token_index = image_token_index
self.tie_word_embeddings = tie_word_embeddings
attn_implementation = kwargs.pop("attn_implementation", None)
# Set the default attention implementation to flash_attention_2 if not
# specified
self._attn_implementation = ("flash_attention_2"
if attn_implementation is None else
attn_implementation)
# Convert the keys and values of projector_patch_to_query_dict to
# integers
# This ensures consistency even if they were provided as strings
self.projector_patch_to_query_dict = {
int(k): int(v)
for k, v in projector_patch_to_query_dict.items()
}
if isinstance(vision_config, dict) and "model_type" in vision_config:
vision_config = AriaVisionConfig(**vision_config)
if attn_implementation is None:
vision_attn_implementation = "flash_attention_2"
elif attn_implementation == "sdpa":
logger.warning("SDPA is not supported for vit, using "
"flash_attention_2 instead")
vision_attn_implementation = "flash_attention_2"
else:
vision_attn_implementation = attn_implementation
vision_config._attn_implementation = vision_attn_implementation
self.vision_config = vision_config
if isinstance(text_config, dict) and "model_type" in text_config:
text_attn_implementation = ("sdpa" if attn_implementation is None
else attn_implementation)
text_config = AriaMoELMConfig(**text_config)
text_config._attn_implementation = text_attn_implementation
self.text_config = text_config
# This is needed for the static kv cache
self.num_hidden_layers = self.text_config.num_hidden_layers