[Model] Upgrade Aria to transformers 4.48 (#12203)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
commit b37d82791e
parent 3127e975fb
Author: Cyrus Leung
Date:   2025-01-20 17:58:48 +08:00
Committed by: GitHub

10 changed files with 178 additions and 379 deletions

vllm/model_executor/models/aria.py

@@ -1,9 +1,11 @@
-from typing import (Callable, Iterable, List, Mapping, Optional, Set, Tuple,
-                    TypedDict, Union)
+from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
+                    Union)
 
 import torch
 import torch.nn as nn
-from transformers import BatchFeature, PretrainedConfig
+from transformers import AriaConfig, AriaTextConfig, BatchFeature
+from transformers.models.aria.modeling_aria import AriaCrossAttention
+from transformers.models.aria.processing_aria import AriaProcessor
 
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, QuantizationConfig, VllmConfig
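The hunk above switches to the configuration classes that ship with transformers 4.48 (AriaConfig, AriaTextConfig); the vendored AriaMoELMConfig/AriaVisionConfig imports are dropped in the next hunk. A minimal sketch of the fields the new code reads, assuming a default-constructed AriaConfig rather than a real checkpoint:

from transformers import AriaConfig, AriaTextConfig

config = AriaConfig()  # defaults only; real use would be AriaConfig.from_pretrained(...)
assert isinstance(config.text_config, AriaTextConfig)

print(config.projector_patch_to_query_dict)  # patch count -> query count, e.g. {1225: 128, 4900: 256}
print(config.vision_config.hidden_size)      # projector in_features / kv_dim
print(config.text_config.hidden_size)        # projector hidden and output size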
@@ -26,10 +28,11 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                        BaseProcessingInfo, PromptReplacement)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
-from vllm.transformers_utils.configs.aria import (AriaMoELMConfig,
-                                                  AriaVisionConfig)
 
-from .idefics2_vision_model import Idefics2VisionTransformer
+# yapf: disable
+from .idefics2_vision_model import (
+    Idefics2VisionTransformer as Idefics3VisionTransformer)
+# yapf: enable
 from .interfaces import SupportsMultiModal
 from .llama import LlamaDecoderLayer, LlamaMLP, LlamaModel
 from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
@@ -47,87 +50,22 @@ class AriaImagePixelInputs(TypedDict):
"""
class AriaVisionTransformer(Idefics2VisionTransformer):
"""
AriaVisionTransformer is a modified version of Idefics2VisionTransformer
that replaces the post-layernorm with an identity layer.
"""
class AriaProjectorMLP(nn.Module):
def __init__(
self,
config: AriaVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__(config, quant_config, prefix)
self.post_layernorm = nn.Identity()
class AriaVisionModel(nn.Module):
config_class = AriaVisionConfig
def __init__(
self,
config: AriaVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
*,
prefix: str = "",
in_features: int,
hidden_features: int,
output_dim: int,
) -> None:
super().__init__()
self.vision_model = AriaVisionTransformer(
config,
quant_config,
prefix=f"{prefix}.vision_model",
)
def forward(
self,
pixel_values: torch.Tensor,
pixel_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
patch_attention_mask = self._create_patch_attention_mask(pixel_mask)
vit_oup = self.vision_model(
pixel_values=pixel_values,
patch_attention_mask=patch_attention_mask,
)
image_atts = self._create_image_attention_mask(patch_attention_mask)
return vit_oup, image_atts
def _create_patch_attention_mask(
self, pixel_mask: Optional[torch.Tensor]) -> torch.Tensor:
if pixel_mask is None:
return None
patches_subgrid = pixel_mask.unfold(
dimension=1,
size=self.vision_model.config.patch_size,
step=self.vision_model.config.patch_size,
).unfold(
dimension=2,
size=self.vision_model.config.patch_size,
step=self.vision_model.config.patch_size,
)
return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
def _create_image_attention_mask(
self, patch_attention_mask: torch.Tensor) -> torch.Tensor:
if patch_attention_mask is None:
return None
flattened_mask = patch_attention_mask.flatten(1)
return torch.logical_not(flattened_mask)
class FFN(nn.Module):
def __init__(self, embed_dim: int, ff_dim: int, output_dim: int) -> None:
super().__init__()
self.linear_in = ColumnParallelLinear(embed_dim, ff_dim, bias=False)
self.linear_out = RowParallelLinear(ff_dim, output_dim, bias=False)
self.linear_in = ColumnParallelLinear(in_features,
hidden_features,
bias=False)
self.linear_out = RowParallelLinear(hidden_features,
output_dim,
bias=False)
self.act = get_act_fn("gelu_new")
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
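The removed FFN class and the new AriaProjectorMLP compute the same thing: a two-layer MLP with a gelu_new activation, with ColumnParallelLinear/RowParallelLinear supplying the tensor-parallel sharding. A rough single-GPU equivalent (a sketch with illustrative sizes, not the vLLM implementation):

import torch
import torch.nn as nn

class ProjectorMLPSketch(nn.Module):

    def __init__(self, in_features: int, hidden_features: int, output_dim: int) -> None:
        super().__init__()
        self.linear_in = nn.Linear(in_features, hidden_features, bias=False)
        self.linear_out = nn.Linear(hidden_features, output_dim, bias=False)
        self.act = nn.GELU(approximate="tanh")  # "gelu_new" is the tanh-approximated GELU

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.linear_out(self.act(self.linear_in(hidden_states)))

# Illustrative sizes only (vision hidden -> text hidden); weights are random here.
out = ProjectorMLPSketch(1152, 2560, 2560)(torch.randn(2, 128, 1152))
print(out.shape)  # torch.Size([2, 128, 2560])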
@@ -137,46 +75,6 @@ class FFN(nn.Module):
         return hidden_states
 
 
-class CrossAttention(nn.Module):
-
-    def __init__(self, kv_dim: int, embed_dim: int, num_heads: int) -> None:
-        super().__init__()
-        self.num_heads = num_heads
-        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
-        self.k_proj = nn.Linear(kv_dim, embed_dim, bias=False)
-        self.v_proj = nn.Linear(kv_dim, embed_dim, bias=False)
-
-        self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
-        self.linear = nn.Linear(embed_dim, embed_dim)
-
-        self.layer_norm = nn.LayerNorm(embed_dim)
-        self.ln_kv = nn.LayerNorm(kv_dim)
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        hidden_states: torch.Tensor,
-        attn_mask: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        normed_hidden_states = self.layer_norm(hidden_states)
-        query = self.q_proj(normed_hidden_states).permute(1, 0, 2)
-
-        x = self.ln_kv(x)
-        key = self.k_proj(x).permute(1, 0, 2)
-        value = self.v_proj(x).permute(1, 0, 2)
-
-        attn_output, _ = self.multihead_attn(query,
-                                             key,
-                                             value,
-                                             attn_mask=attn_mask)
-
-        attn_output = attn_output.permute(1, 0, 2)
-        attn_output = self.linear(attn_output)
-
-        return attn_output
-
-
 class AriaProjector(nn.Module):
     """
     A projection module with one cross attention layer and one FFN layer, which
@@ -198,42 +96,42 @@ class AriaProjector(nn.Module):
         A tensor with the shape of (batch_size, query_number, output_dim)
     """
 
-    def __init__(
-        self,
-        patch_to_query_dict: dict[int, int],
-        embed_dim: int,
-        num_heads: int,
-        kv_dim: int,
-        ff_dim: int,
-        output_dim: int,
-        norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
-    ) -> None:
+    def __init__(self, config: AriaConfig) -> None:
         super().__init__()
 
-        self.patch_to_query_dict = patch_to_query_dict
-        self.embed_dim = embed_dim
-        self.num_heads = num_heads
+        self.patch_to_query_dict = config.projector_patch_to_query_dict
+        self.in_features = config.vision_config.hidden_size
+        self.num_heads = config.vision_config.num_attention_heads
+        self.kv_dim = config.vision_config.hidden_size
+        self.hidden_features = config.text_config.hidden_size
+        self.output_dim = config.text_config.hidden_size
 
         self.query = nn.Parameter(
-            torch.empty(max(patch_to_query_dict.values()), self.embed_dim))
+            torch.empty(config.max_value_projector_patch_to_query_dict,
+                        self.in_features))
 
-        self.cross_attn = CrossAttention(kv_dim, embed_dim, num_heads)
+        self.cross_attn = AriaCrossAttention(config)
 
-        self.ln_ffn = norm_layer(embed_dim)
-        self.ffn = FFN(embed_dim, ff_dim, output_dim)
+        self.layer_norm = nn.LayerNorm(self.in_features)
+        self.feed_forward = AriaProjectorMLP(self.in_features,
+                                             self.hidden_features,
+                                             self.output_dim)
 
     def forward(
         self,
         x: torch.Tensor,
         attn_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        bs = x.shape[0]
-        queries = self.query.unsqueeze(0).repeat(bs, 1, 1)
+        batch_size, num_patches = x.shape[0], x.shape[1]
 
-        query_num = self.patch_to_query_dict.get(x.shape[1], None)
-        assert (query_num is not None
-                ), f"Query number for {x.shape[1]} patches is not provided"
+        if num_patches not in self.patch_to_query_dict:
+            raise KeyError(f"Number of patches {num_patches} not found in "
+                           "patch_to_query_dict amongst possible values "
+                           f"{self.patch_to_query_dict.keys()}.")
 
-        queries = queries[:, :query_num, :]
+        query_num = self.patch_to_query_dict[num_patches]
+
+        queries = self.query[:query_num].unsqueeze(0).repeat(batch_size, 1, 1)
 
         if attn_mask is not None:
             attn_mask = attn_mask.repeat_interleave(self.num_heads, 0)
@@ -241,7 +139,7 @@ class AriaProjector(nn.Module):
         attention_out = self.cross_attn(x, queries, attn_mask=attn_mask)
 
-        out = self.ffn(self.ln_ffn(attention_out))
+        out = self.feed_forward(self.layer_norm(attention_out))
 
         return out
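The rewritten AriaProjector.forward keeps the original lookup but raises a KeyError instead of asserting: the number of patches produced by the vision tower selects how many of the learned query vectors are cross-attended against the image features. A toy illustration of that selection, using an assumed default mapping:

import torch

patch_to_query_dict = {1225: 128, 4900: 256}  # assumed default Aria mapping
query = torch.empty(max(patch_to_query_dict.values()), 1152)  # (max_queries, in_features)

num_patches = 1225  # e.g. a 490x490 crop with patch size 14 -> 35 * 35 patches
if num_patches not in patch_to_query_dict:
    raise KeyError(f"Number of patches {num_patches} not found")

queries = query[:patch_to_query_dict[num_patches]].unsqueeze(0).repeat(4, 1, 1)
print(queries.shape)  # torch.Size([4, 128, 1152]) for a batch of 4 images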
@@ -278,7 +176,7 @@ class AriaFusedMoE(FusedMoE):
             param.data.copy_(loaded_weight.transpose(1, 2))
 
 
-class MoELayer(nn.Module):
+class AriaTextMoELayer(nn.Module):
     """
     Mixture of Experts (MoE) Layer for the AriaMoE model.
@@ -289,7 +187,7 @@ class MoELayer(nn.Module):
     def __init__(
         self,
-        config: AriaMoELMConfig,
+        config: AriaTextConfig,
         quant_config: Optional[QuantizationConfig],
     ) -> None:
         super().__init__()
@@ -303,15 +201,16 @@ class MoELayer(nn.Module):
             num_experts=config.moe_num_experts,
             top_k=config.moe_topk,
             hidden_size=config.hidden_size,
-            intermediate_size=config.moe_intermediate_size,
+            intermediate_size=config.intermediate_size,
             quant_config=quant_config,
             reduce_results=True,
         )
         self.shared_experts = LlamaMLP(
             config.hidden_size,
-            config.moe_intermediate_size * config.moe_num_shared_experts,
+            config.intermediate_size * config.moe_num_shared_experts,
             "silu",
             quant_config=quant_config,
+            bias=config.mlp_bias,
         )
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -329,13 +228,13 @@ class MoELayer(nn.Module):
         router_output = torch.nn.functional.linear(hidden_states,
                                                    self.router_weight)
 
-        shared_expert_output = self.shared_experts(hidden_states)
         sparse_expert_output = self.experts(hidden_states, router_output)
+        shared_expert_output = self.shared_experts(hidden_states)
 
         return sparse_expert_output + shared_expert_output
 
 
-class MoEDecoderLayer(LlamaDecoderLayer):
+class AriaTextDecoderLayer(LlamaDecoderLayer):
     """
     Custom Decoder Layer for the AriaMoE model which modifies the standard
     `LlamaDecoderLayer` by replacing the traditional MLP with a Mixture of
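For context, the forward shown earlier in this hunk routes every token through both a shared LlamaMLP and the top-k experts chosen by the router, then sums the two paths; the commit only renames the class and swaps the evaluation order of the two branches. A dense, single-device sketch of that combination (toy sizes, unrelated to the fused vLLM kernels):

import torch

hidden, n_experts, top_k = 8, 4, 2
x = torch.randn(3, hidden)                    # 3 tokens
router_weight = torch.randn(n_experts, hidden)

router_output = torch.nn.functional.linear(x, router_weight)       # (3, n_experts)
topk_weights, topk_idx = torch.topk(router_output.softmax(dim=-1), top_k)

expert_mlps = [torch.nn.Linear(hidden, hidden) for _ in range(n_experts)]
shared_mlp = torch.nn.Linear(hidden, hidden)

sparse_expert_output = torch.zeros_like(x)
for t in range(x.size(0)):
    for w, e in zip(topk_weights[t], topk_idx[t]):
        sparse_expert_output[t] += w * expert_mlps[e](x[t])

shared_expert_output = shared_mlp(x)
print((sparse_expert_output + shared_expert_output).shape)  # torch.Size([3, 8])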
@@ -344,16 +243,16 @@ class MoEDecoderLayer(LlamaDecoderLayer):
     def __init__(
         self,
-        config: AriaMoELMConfig,
+        config: AriaTextConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
         super().__init__(config, cache_config, quant_config, prefix)
-        self.mlp = MoELayer(config, quant_config=quant_config)
+        self.mlp = AriaTextMoELayer(config, quant_config=quant_config)
 
 
-class AriaMoELMModel(LlamaModel):
+class AriaTextModel(LlamaModel):
     """
     Custom LlamaModel for the AriaMoE model which modifies the standard
     LlamaModel by replacing the `LlamaDecoderLayer` with `MoEDecoderLayer`.
@@ -362,7 +261,7 @@ class AriaMoELMModel(LlamaModel):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__(vllm_config=vllm_config,
                          prefix=prefix,
-                         layer_type=MoEDecoderLayer)
+                         layer_type=AriaTextDecoderLayer)
 
     # Adapted from LlamaModel.load_weights with the modification of adding
     # the expert weights mapping to `stacked_params_mapping`
@@ -434,25 +333,23 @@ class AriaMoELMModel(LlamaModel):
         return loaded_params
 
 
-def build_mm_projector(config: PretrainedConfig):
-    return AriaProjector(
-        patch_to_query_dict=config.projector_patch_to_query_dict,
-        embed_dim=config.vision_config.hidden_size,
-        num_heads=config.vision_config.num_attention_heads,
-        kv_dim=config.vision_config.hidden_size,
-        ff_dim=config.text_config.hidden_size,
-        output_dim=config.text_config.hidden_size,
-    )
-
-
 class AriaProcessingInfo(BaseProcessingInfo):
 
     def get_hf_config(self):
-        return self.ctx.get_hf_config()
+        return self.ctx.get_hf_config(AriaConfig)
 
-    def get_vision_config(self) -> AriaVisionConfig:
+    def get_vision_config(self):
         return self.get_hf_config().vision_config
 
+    def get_hf_processor(self):
+        processor = self.ctx.get_hf_processor(AriaProcessor)
+
+        # Patch for https://github.com/huggingface/transformers/issues/35768
+        processor.tokenizer.image_token = "<|img|>"
+        processor.image_token = "<|img|>"
+
+        return processor
+
     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
         return {"image": None}
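The new get_hf_processor pins the image placeholder token on both the processor and its tokenizer as a workaround for huggingface/transformers#35768. Outside vLLM the same workaround would look roughly like this (a sketch, assuming transformers >= 4.48 and network access; the rhymes-ai/Aria checkpoint id is an assumption here):

from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("rhymes-ai/Aria")  # assumed checkpoint id
processor.tokenizer.image_token = "<|img|>"  # same patch as in the diff above
processor.image_token = "<|img|>"

print(processor.tokenizer.convert_tokens_to_ids("<|img|>"))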
@@ -554,10 +451,14 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
         quant_config = vllm_config.quant_config
 
         self.config = config
-        self.vision_tower = AriaVisionModel(config.vision_config)
-        self.multi_modal_projector = build_mm_projector(config)
+        self.vision_tower = Idefics3VisionTransformer(
+            config.vision_config,
+            quant_config,
+            prefix=f"{prefix}.vision_tower",
+        )
+        self.multi_modal_projector = AriaProjector(config)
         self.vocab_size = config.text_config.vocab_size
-        self.language_model = AriaMoELMModel(
+        self.language_model = AriaTextModel(
             vllm_config=vllm_config.with_hf_config(config.text_config),
             prefix=maybe_prefix(prefix, "language_model.model"),
         )
@@ -608,6 +509,22 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
             pixel_mask=pixel_mask,
         )
 
+    def _create_patch_attention_mask(
+            self, pixel_mask: Optional[torch.Tensor]) -> torch.Tensor:
+        if pixel_mask is None:
+            return None
+
+        patches_subgrid = pixel_mask.unfold(
+            dimension=1,
+            size=self.vision_tower.config.patch_size,
+            step=self.vision_tower.config.patch_size,
+        ).unfold(
+            dimension=2,
+            size=self.vision_tower.config.patch_size,
+            step=self.vision_tower.config.patch_size,
+        )
+        return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
+
     def _process_image_input(
             self, image_input: AriaImagePixelInputs
     ) -> Tuple[torch.Tensor, torch.Tensor]:
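_create_patch_attention_mask, added to AriaForConditionalGeneration above, downsamples the per-pixel validity mask to one boolean per ViT patch: unfold tiles the mask into non-overlapping patch_size x patch_size blocks, and a patch is kept if any pixel inside its block is set. A toy run with assumed sizes (patch size 14 is an assumption):

import torch

patch_size = 14
pixel_mask = torch.zeros(1, 490, 490, dtype=torch.bool)  # (num_images, height, width)
pixel_mask[:, :280, :350] = True                          # only the top-left region holds real pixels

patches_subgrid = pixel_mask.unfold(1, patch_size, patch_size) \
                            .unfold(2, patch_size, patch_size)
patch_attention_mask = patches_subgrid.sum(dim=(-1, -2)) > 0

print(patch_attention_mask.shape)         # torch.Size([1, 35, 35])
print(patch_attention_mask.sum().item())  # 500 == (280 // 14) * (350 // 14) valid patches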
@@ -616,9 +533,18 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
         pixel_values = image_input['pixel_values']
         pixel_mask = image_input['pixel_mask']
 
-        image_feature, image_attn_mask = self.vision_tower(
-            pixel_values, pixel_mask=pixel_mask)
-        return self.multi_modal_projector(image_feature, image_attn_mask)
+        patch_attention_mask = self._create_patch_attention_mask(pixel_mask)
+
+        image_outputs = self.vision_tower(
+            pixel_values=pixel_values,
+            patch_attention_mask=patch_attention_mask,
+        )
+        image_attn_mask = None
+        if patch_attention_mask is not None:
+            flattened_mask = patch_attention_mask.flatten(1)
+            image_attn_mask = torch.logical_not(flattened_mask)
+
+        return self.multi_modal_projector(image_outputs, image_attn_mask)
 
     def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
         image_input = self._parse_and_validate_image_input(**kwargs)
@@ -683,6 +609,5 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
         return next_tokens
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         loader = AutoWeightsLoader(self)
         loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
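load_weights delegates to AutoWeightsLoader with the module's hf_to_vllm_mapper, which renames checkpoint keys from the transformers layout to the vLLM module tree. A self-contained sketch of that kind of prefix rewriting (the mapping shown is illustrative, not the exact one defined in this file):

def remap_prefix(name: str, prefix_map: dict) -> str:
    """Rewrite the first matching prefix of a checkpoint key."""
    for old, new in prefix_map.items():
        if name.startswith(old):
            return new + name[len(old):]
    return name

# Hypothetical mapping: collapse the HF "language_model.model." nesting.
mapping = {"language_model.model.": "language_model."}
print(remap_prefix("language_model.model.layers.0.self_attn.q_proj.weight", mapping))
# -> language_model.layers.0.self_attn.q_proj.weight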