[Model] Upgrade Aria to transformers 4.48 (#12203)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -1,9 +1,11 @@
-from typing import (Callable, Iterable, List, Mapping, Optional, Set, Tuple,
-                    TypedDict, Union)
+from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
+                    Union)
 
 import torch
 import torch.nn as nn
-from transformers import BatchFeature, PretrainedConfig
+from transformers import AriaConfig, AriaTextConfig, BatchFeature
+from transformers.models.aria.modeling_aria import AriaCrossAttention
+from transformers.models.aria.processing_aria import AriaProcessor
 
 from vllm.attention import AttentionMetadata
 from vllm.config import CacheConfig, QuantizationConfig, VllmConfig
@@ -26,10 +28,11 @@ from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, PromptReplacement)
 from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 from vllm.sequence import IntermediateTensors
-from vllm.transformers_utils.configs.aria import (AriaMoELMConfig,
-                                                   AriaVisionConfig)
 
-from .idefics2_vision_model import Idefics2VisionTransformer
+# yapf: disable
+from .idefics2_vision_model import (
+    Idefics2VisionTransformer as Idefics3VisionTransformer)
+# yapf: enable
 from .interfaces import SupportsMultiModal
 from .llama import LlamaDecoderLayer, LlamaMLP, LlamaModel
 from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
@@ -47,87 +50,22 @@ class AriaImagePixelInputs(TypedDict):
     """
 
 
-class AriaVisionTransformer(Idefics2VisionTransformer):
-    """
-    AriaVisionTransformer is a modified version of Idefics2VisionTransformer
-    that replaces the post-layernorm with an identity layer.
-    """
+class AriaProjectorMLP(nn.Module):
 
     def __init__(
         self,
-        config: AriaVisionConfig,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ) -> None:
-        super().__init__(config, quant_config, prefix)
-        self.post_layernorm = nn.Identity()
-
-
-class AriaVisionModel(nn.Module):
-    config_class = AriaVisionConfig
-
-    def __init__(
-        self,
-        config: AriaVisionConfig,
-        quant_config: Optional[QuantizationConfig] = None,
-        *,
-        prefix: str = "",
+        in_features: int,
+        hidden_features: int,
+        output_dim: int,
     ) -> None:
         super().__init__()
 
-        self.vision_model = AriaVisionTransformer(
-            config,
-            quant_config,
-            prefix=f"{prefix}.vision_model",
-        )
-
-    def forward(
-        self,
-        pixel_values: torch.Tensor,
-        pixel_mask: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
-        patch_attention_mask = self._create_patch_attention_mask(pixel_mask)
-
-        vit_oup = self.vision_model(
-            pixel_values=pixel_values,
-            patch_attention_mask=patch_attention_mask,
-        )
-
-        image_atts = self._create_image_attention_mask(patch_attention_mask)
-
-        return vit_oup, image_atts
-
-    def _create_patch_attention_mask(
-            self, pixel_mask: Optional[torch.Tensor]) -> torch.Tensor:
-        if pixel_mask is None:
-            return None
-
-        patches_subgrid = pixel_mask.unfold(
-            dimension=1,
-            size=self.vision_model.config.patch_size,
-            step=self.vision_model.config.patch_size,
-        ).unfold(
-            dimension=2,
-            size=self.vision_model.config.patch_size,
-            step=self.vision_model.config.patch_size,
-        )
-        return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
-
-    def _create_image_attention_mask(
-            self, patch_attention_mask: torch.Tensor) -> torch.Tensor:
-        if patch_attention_mask is None:
-            return None
-
-        flattened_mask = patch_attention_mask.flatten(1)
-        return torch.logical_not(flattened_mask)
-
-
-class FFN(nn.Module):
-
-    def __init__(self, embed_dim: int, ff_dim: int, output_dim: int) -> None:
-        super().__init__()
-        self.linear_in = ColumnParallelLinear(embed_dim, ff_dim, bias=False)
-        self.linear_out = RowParallelLinear(ff_dim, output_dim, bias=False)
+        self.linear_in = ColumnParallelLinear(in_features,
+                                              hidden_features,
+                                              bias=False)
+        self.linear_out = RowParallelLinear(hidden_features,
+                                            output_dim,
+                                            bias=False)
         self.act = get_act_fn("gelu_new")
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -137,46 +75,6 @@ class FFN(nn.Module):
         return hidden_states
 
 
-class CrossAttention(nn.Module):
-
-    def __init__(self, kv_dim: int, embed_dim: int, num_heads: int) -> None:
-        super().__init__()
-        self.num_heads = num_heads
-        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
-        self.k_proj = nn.Linear(kv_dim, embed_dim, bias=False)
-        self.v_proj = nn.Linear(kv_dim, embed_dim, bias=False)
-
-        self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
-        self.linear = nn.Linear(embed_dim, embed_dim)
-
-        self.layer_norm = nn.LayerNorm(embed_dim)
-        self.ln_kv = nn.LayerNorm(kv_dim)
-
-    def forward(
-        self,
-        x: torch.Tensor,
-        hidden_states: torch.Tensor,
-        attn_mask: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        normed_hidden_states = self.layer_norm(hidden_states)
-        query = self.q_proj(normed_hidden_states).permute(1, 0, 2)
-
-        x = self.ln_kv(x)
-        key = self.k_proj(x).permute(1, 0, 2)
-        value = self.v_proj(x).permute(1, 0, 2)
-
-        attn_output, _ = self.multihead_attn(query,
-                                             key,
-                                             value,
-                                             attn_mask=attn_mask)
-
-        attn_output = attn_output.permute(1, 0, 2)
-
-        attn_output = self.linear(attn_output)
-
-        return attn_output
-
-
 class AriaProjector(nn.Module):
     """
     A projection module with one cross attention layer and one FFN layer, which
@@ -198,42 +96,42 @@ class AriaProjector(nn.Module):
         A tensor with the shape of (batch_size, query_number, output_dim)
     """
 
-    def __init__(
-        self,
-        patch_to_query_dict: dict[int, int],
-        embed_dim: int,
-        num_heads: int,
-        kv_dim: int,
-        ff_dim: int,
-        output_dim: int,
-        norm_layer: Callable[[int], nn.Module] = nn.LayerNorm,
-    ) -> None:
+    def __init__(self, config: AriaConfig) -> None:
         super().__init__()
-        self.patch_to_query_dict = patch_to_query_dict
-        self.embed_dim = embed_dim
-        self.num_heads = num_heads
+
+        self.patch_to_query_dict = config.projector_patch_to_query_dict
+        self.in_features = config.vision_config.hidden_size
+        self.num_heads = config.vision_config.num_attention_heads
+        self.kv_dim = config.vision_config.hidden_size
+        self.hidden_features = config.text_config.hidden_size
+        self.output_dim = config.text_config.hidden_size
 
         self.query = nn.Parameter(
-            torch.empty(max(patch_to_query_dict.values()), self.embed_dim))
+            torch.empty(config.max_value_projector_patch_to_query_dict,
+                        self.in_features))
 
-        self.cross_attn = CrossAttention(kv_dim, embed_dim, num_heads)
+        self.cross_attn = AriaCrossAttention(config)
 
-        self.ln_ffn = norm_layer(embed_dim)
-        self.ffn = FFN(embed_dim, ff_dim, output_dim)
+        self.layer_norm = nn.LayerNorm(self.in_features)
+        self.feed_forward = AriaProjectorMLP(self.in_features,
+                                             self.hidden_features,
+                                             self.output_dim)
 
     def forward(
         self,
         x: torch.Tensor,
         attn_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        bs = x.shape[0]
-        queries = self.query.unsqueeze(0).repeat(bs, 1, 1)
+        batch_size, num_patches = x.shape[0], x.shape[1]
 
-        query_num = self.patch_to_query_dict.get(x.shape[1], None)
-        assert (query_num is not None
-                ), f"Query number for {x.shape[1]} patches is not provided"
+        if num_patches not in self.patch_to_query_dict:
+            raise KeyError(f"Number of patches {num_patches} not found in "
+                           "patch_to_query_dict amongst possible values "
+                           f"{self.patch_to_query_dict.keys()}.")
 
-        queries = queries[:, :query_num, :]
+        query_num = self.patch_to_query_dict[num_patches]
+
+        queries = self.query[:query_num].unsqueeze(0).repeat(batch_size, 1, 1)
 
         if attn_mask is not None:
             attn_mask = attn_mask.repeat_interleave(self.num_heads, 0)
@@ -241,7 +139,7 @@ class AriaProjector(nn.Module):
 
         attention_out = self.cross_attn(x, queries, attn_mask=attn_mask)
 
-        out = self.ffn(self.ln_ffn(attention_out))
+        out = self.feed_forward(self.layer_norm(attention_out))
 
         return out
 
@@ -278,7 +176,7 @@ class AriaFusedMoE(FusedMoE):
         param.data.copy_(loaded_weight.transpose(1, 2))
 
 
-class MoELayer(nn.Module):
+class AriaTextMoELayer(nn.Module):
     """
     Mixture of Experts (MoE) Layer for the AriaMoE model.
 
@@ -289,7 +187,7 @@ class MoELayer(nn.Module):
 
     def __init__(
         self,
-        config: AriaMoELMConfig,
+        config: AriaTextConfig,
         quant_config: Optional[QuantizationConfig],
     ) -> None:
         super().__init__()
@@ -303,15 +201,16 @@ class MoELayer(nn.Module):
             num_experts=config.moe_num_experts,
             top_k=config.moe_topk,
             hidden_size=config.hidden_size,
-            intermediate_size=config.moe_intermediate_size,
+            intermediate_size=config.intermediate_size,
             quant_config=quant_config,
             reduce_results=True,
         )
         self.shared_experts = LlamaMLP(
             config.hidden_size,
-            config.moe_intermediate_size * config.moe_num_shared_experts,
+            config.intermediate_size * config.moe_num_shared_experts,
             "silu",
             quant_config=quant_config,
+            bias=config.mlp_bias,
         )
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -329,13 +228,13 @@ class MoELayer(nn.Module):
         router_output = torch.nn.functional.linear(hidden_states,
                                                    self.router_weight)
 
-        shared_expert_output = self.shared_experts(hidden_states)
         sparse_expert_output = self.experts(hidden_states, router_output)
+        shared_expert_output = self.shared_experts(hidden_states)
 
         return sparse_expert_output + shared_expert_output
 
 
-class MoEDecoderLayer(LlamaDecoderLayer):
+class AriaTextDecoderLayer(LlamaDecoderLayer):
     """
     Custom Decoder Layer for the AriaMoE model which modifies the standard
     `LlamaDecoderLayer` by replacing the traditional MLP with a Mixture of
|
 
     def __init__(
         self,
-        config: AriaMoELMConfig,
+        config: AriaTextConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
         super().__init__(config, cache_config, quant_config, prefix)
-        self.mlp = MoELayer(config, quant_config=quant_config)
+        self.mlp = AriaTextMoELayer(config, quant_config=quant_config)
 
 
-class AriaMoELMModel(LlamaModel):
+class AriaTextModel(LlamaModel):
     """
     Custom LlamaModel for the AriaMoE model which modifies the standard
     LlamaModel by replacing the `LlamaDecoderLayer` with `MoEDecoderLayer`.
@@ -362,7 +261,7 @@ class AriaMoELMModel(LlamaModel):
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__(vllm_config=vllm_config,
                          prefix=prefix,
-                         layer_type=MoEDecoderLayer)
+                         layer_type=AriaTextDecoderLayer)
 
     # Adapted from LlamaModel.load_weights with the modification of adding
     # the expert weights mapping to `stacked_params_mapping`
@@ -434,25 +333,23 @@ class AriaMoELMModel(LlamaModel):
         return loaded_params
 
 
-def build_mm_projector(config: PretrainedConfig):
-    return AriaProjector(
-        patch_to_query_dict=config.projector_patch_to_query_dict,
-        embed_dim=config.vision_config.hidden_size,
-        num_heads=config.vision_config.num_attention_heads,
-        kv_dim=config.vision_config.hidden_size,
-        ff_dim=config.text_config.hidden_size,
-        output_dim=config.text_config.hidden_size,
-    )
-
-
 class AriaProcessingInfo(BaseProcessingInfo):
 
     def get_hf_config(self):
-        return self.ctx.get_hf_config()
+        return self.ctx.get_hf_config(AriaConfig)
 
-    def get_vision_config(self) -> AriaVisionConfig:
+    def get_vision_config(self):
         return self.get_hf_config().vision_config
 
+    def get_hf_processor(self):
+        processor = self.ctx.get_hf_processor(AriaProcessor)
+
+        # Patch for https://github.com/huggingface/transformers/issues/35768
+        processor.tokenizer.image_token = "<|img|>"
+        processor.image_token = "<|img|>"
+
+        return processor
+
     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
         return {"image": None}
 
@@ -554,10 +451,14 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
         quant_config = vllm_config.quant_config
 
         self.config = config
-        self.vision_tower = AriaVisionModel(config.vision_config)
-        self.multi_modal_projector = build_mm_projector(config)
+        self.vision_tower = Idefics3VisionTransformer(
+            config.vision_config,
+            quant_config,
+            prefix=f"{prefix}.vision_tower",
+        )
+        self.multi_modal_projector = AriaProjector(config)
         self.vocab_size = config.text_config.vocab_size
-        self.language_model = AriaMoELMModel(
+        self.language_model = AriaTextModel(
             vllm_config=vllm_config.with_hf_config(config.text_config),
             prefix=maybe_prefix(prefix, "language_model.model"),
         )
@@ -608,6 +509,22 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
             pixel_mask=pixel_mask,
         )
 
+    def _create_patch_attention_mask(
+            self, pixel_mask: Optional[torch.Tensor]) -> torch.Tensor:
+        if pixel_mask is None:
+            return None
+
+        patches_subgrid = pixel_mask.unfold(
+            dimension=1,
+            size=self.vision_tower.config.patch_size,
+            step=self.vision_tower.config.patch_size,
+        ).unfold(
+            dimension=2,
+            size=self.vision_tower.config.patch_size,
+            step=self.vision_tower.config.patch_size,
+        )
+        return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool()
+
     def _process_image_input(
         self, image_input: AriaImagePixelInputs
     ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -616,9 +533,18 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
         pixel_values = image_input['pixel_values']
         pixel_mask = image_input['pixel_mask']
 
-        image_feature, image_attn_mask = self.vision_tower(
-            pixel_values, pixel_mask=pixel_mask)
-        return self.multi_modal_projector(image_feature, image_attn_mask)
+        patch_attention_mask = self._create_patch_attention_mask(pixel_mask)
+
+        image_outputs = self.vision_tower(
+            pixel_values=pixel_values,
+            patch_attention_mask=patch_attention_mask,
+        )
+        image_attn_mask = None
+        if patch_attention_mask is not None:
+            flattened_mask = patch_attention_mask.flatten(1)
+            image_attn_mask = torch.logical_not(flattened_mask)
+
+        return self.multi_modal_projector(image_outputs, image_attn_mask)
 
     def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
         image_input = self._parse_and_validate_image_input(**kwargs)
@@ -683,6 +609,5 @@ class AriaForConditionalGeneration(nn.Module, SupportsMultiModal):
         return next_tokens
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-
         loader = AutoWeightsLoader(self)
         loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)