fix minicpmo4.5: fix attn_mask in vit attn && fix resampler pos_emb i… (#34127)
Signed-off-by: tc-mb <caitianchi@modelbest.cn>
Co-authored-by: hezhihui <hezhihui@modelbest.cn>
@@ -22,6 +22,7 @@ from collections.abc import Iterable
 import torch
 from torch import nn
+from torch.nn import functional as F
 from transformers.models.idefics2.configuration_idefics2 import (
     Idefics2Config,
     Idefics2VisionConfig,

@@ -172,14 +173,41 @@ class Idefics2VisionAttention(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor | None = None,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(
             hidden_states
         )  # batch_size, q_len, 3 * num_heads_per_partition * head_dim
         query_states, key_states, value_states = qkv.chunk(3, dim=-1)

-        # Use unified MMEncoderAttention implementation
-        out = self.attn(query_states, key_states, value_states)
+        # If attention_mask is provided, prefer Torch SDPA so the mask is
+        # correctly applied (aligns with HuggingFace NaViT SigLIP behavior).
+        if attention_mask is None:
+            # Use unified MMEncoderAttention implementation
+            out = self.attn(query_states, key_states, value_states)
+        else:
+            bsz, q_len = query_states.size()[:2]
+            kv_len = key_states.size(1)

+            query = query_states.view(
+                bsz, q_len, self.num_heads_per_partition, self.head_dim
+            ).transpose(1, 2)
+            key = key_states.view(
+                bsz, kv_len, self.num_heads_per_partition, self.head_dim
+            ).transpose(1, 2)
+            value = value_states.view(
+                bsz, kv_len, self.num_heads_per_partition, self.head_dim
+            ).transpose(1, 2)

+            out = F.scaled_dot_product_attention(
+                query,
+                key,
+                value,
+                attn_mask=attention_mask,
+                dropout_p=0.0,
+                scale=self.scale,
+            )
+            out = out.transpose(1, 2).reshape(bsz, q_len, -1)
         attn_output, _ = self.out_proj(out)
         return attn_output
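
A minimal standalone sketch (not part of this diff) of the masked SDPA path above, assuming toy shapes and random tensors; it shows the additive (B, 1, 1, L) mask format the new branch expects and how it is passed to F.scaled_dot_product_attention.

import torch
from torch.nn import functional as F

bsz, num_heads, q_len, head_dim = 2, 4, 6, 8
query = torch.randn(bsz, num_heads, q_len, head_dim)
key = torch.randn(bsz, num_heads, q_len, head_dim)
value = torch.randn(bsz, num_heads, q_len, head_dim)

# Boolean patch mask: True = real patch, False = padding (second sample padded).
patch_mask = torch.ones(bsz, q_len, dtype=torch.bool)
patch_mask[1, 4:] = False

# Additive mask, broadcastable from (B, 1, 1, L) to (B, H, Q, K).
min_val = torch.finfo(query.dtype).min
attn_mask = (~patch_mask).to(query.dtype) * min_val
attn_mask = attn_mask[:, None, None, :]

out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)
print(out.shape)  # torch.Size([2, 4, 6, 8])
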
@@ -245,6 +273,7 @@ class Idefics2EncoderLayer(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor | None = None,
     ) -> torch.Tensor:
         """
         Args:

@@ -254,7 +283,7 @@ class Idefics2EncoderLayer(nn.Module):
         """
         residual = hidden_states
         hidden_states = self.layer_norm1(hidden_states)
-        hidden_states = self.self_attn(hidden_states)
+        hidden_states = self.self_attn(hidden_states, attention_mask=attention_mask)
         hidden_states += residual
         residual = hidden_states
         hidden_states = self.layer_norm2(hidden_states)

@@ -304,6 +333,7 @@ class Idefics2Encoder(nn.Module):
     def forward(
         self,
         inputs_embeds: torch.Tensor,
+        attention_mask: torch.Tensor | None = None,
     ) -> torch.Tensor:
         r"""
         Args:

@@ -316,7 +346,7 @@ class Idefics2Encoder(nn.Module):
         """
         hidden_states = inputs_embeds
         for encoder_layer in self.layers:
-            layer_outputs = encoder_layer(hidden_states)
+            layer_outputs = encoder_layer(hidden_states, attention_mask=attention_mask)
             hidden_states = layer_outputs
         return hidden_states
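
A hedged sanity check (illustrative only, toy shapes, not code from this diff) of why the mask has to be threaded through every encoder layer: once the additive mask reaches the attention call, the contents of padded key/value positions cannot influence the output at any position.

import torch
from torch.nn import functional as F

q = torch.randn(1, 2, 5, 4)
k = torch.randn(1, 2, 5, 4)
v = torch.randn(1, 2, 5, 4)

valid = torch.tensor([[True, True, True, False, False]])
attn_mask = ((~valid).float() * torch.finfo(torch.float32).min)[:, None, None, :]

out_ref = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

# Scramble the padded positions; masked attention should not notice.
k2, v2 = k.clone(), v.clone()
k2[..., 3:, :] = torch.randn_like(k2[..., 3:, :])
v2[..., 3:, :] = torch.randn_like(v2[..., 3:, :])
out_new = F.scaled_dot_product_attention(q, k2, v2, attn_mask=attn_mask)

print(torch.allclose(out_ref, out_new, atol=1e-6))  # True
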
@@ -370,15 +400,47 @@ class Idefics2VisionTransformer(nn.Module):
         patch_attention_mask: torch.BoolTensor | None = None,
         tgt_sizes: torch.IntTensor | None = None,
     ) -> torch.Tensor:
+        batch_size = pixel_values.size(0)

+        if patch_attention_mask is None:
+            # No mask provided - create default all-ones mask for embeddings
+            # and skip attention masking (no padding to mask)
+            patch_attention_mask = torch.ones(
+                size=(
+                    batch_size,
+                    pixel_values.size(2) // self.config.patch_size,
+                    pixel_values.size(3) // self.config.patch_size,
+                ),
+                dtype=torch.bool,
+                device=pixel_values.device,
+            )
+            flat_patch_mask = None
+        else:
+            flat_patch_mask = patch_attention_mask.view(batch_size, -1)

         hidden_states = self.embeddings(
             pixel_values=pixel_values,
             patch_attention_mask=patch_attention_mask,
             tgt_sizes=tgt_sizes,
         )

+        # Align with HuggingFace NaViT SigLIP in MiniCPMV/O:
+        # - if patch_attention_mask was None, skip attention masking
+        # - if any padding exists, create an additive 4D mask and pass it
+        #   to attention; else skip mask for performance.
+        if flat_patch_mask is None or not torch.any(~flat_patch_mask):
+            attention_mask = None
+        else:
+            # Additive mask: masked positions receive a large negative value.
+            # Shape: (B, 1, 1, L) broadcastable to (B, H, Q, K).
+            min_val = torch.finfo(hidden_states.dtype).min
+            attention_mask = (~flat_patch_mask).to(dtype=hidden_states.dtype) * min_val
+            attention_mask = attention_mask[:, None, None, :]

         if self.use_data_parallel:
             encoder_outputs = run_dp_sharded_vision_model(hidden_states, self.encoder)
         else:
-            encoder_outputs = self.encoder(hidden_states)
+            encoder_outputs = self.encoder(hidden_states, attention_mask=attention_mask)
         last_hidden_state = self.post_layernorm(encoder_outputs)
         return last_hidden_state
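
The mask-construction branch above can be read as a small helper. The sketch below mirrors its logic under toy inputs (the helper name and sizes are assumptions, not code from this diff): no mask, or a fully valid mask, yields None, while any padded patch yields an additive (B, 1, 1, L) mask.

import torch

def build_patch_attention_mask(
    patch_attention_mask: torch.Tensor | None, dtype: torch.dtype
) -> torch.Tensor | None:
    """Return an additive (B, 1, 1, L) mask, or None when nothing is padded."""
    if patch_attention_mask is None:
        return None
    flat = patch_attention_mask.view(patch_attention_mask.size(0), -1)
    if not torch.any(~flat):
        # Every patch is real: skip masking entirely for performance.
        return None
    min_val = torch.finfo(dtype).min
    additive = (~flat).to(dtype=dtype) * min_val
    return additive[:, None, None, :]

# Two images on a 3x4 patch grid; the second image has a padded last column.
mask = torch.ones(2, 3, 4, dtype=torch.bool)
mask[1, :, -1] = False
print(build_patch_attention_mask(mask, torch.float16).shape)              # torch.Size([2, 1, 1, 12])
print(build_patch_attention_mask(mask.new_ones(2, 3, 4), torch.float16))  # None
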
@@ -387,8 +387,8 @@ class Resampler4_5(Resampler2_5):
             pos_embed_2d, batch_first=True, padding_value=0.0
         ).permute(1, 0, 2)  # BLD => L * B * D

-        k = x
-        v = x + pos_embed_2d
+        k = x + pos_embed_2d
+        v = x
         if pos_embed_temporal:
             k += torch.stack(pos_embed_temporal, dim=0)
         bs = len(temporal_ids)
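
A toy sketch of the pos_emb change, using plain nn.MultiheadAttention rather than the Resampler4_5 code; every shape and name here is illustrative. Adding the 2D positional embedding to the keys lets it steer where the learned queries attend, instead of being summed into the values and leaking into the aggregated features.

import torch
from torch import nn

L, B, D = 16, 1, 32              # patches, batch, hidden dim
x = torch.randn(L, B, D)         # vision features, L * B * D as in the resampler
pos_embed_2d = torch.randn(L, B, D)
queries = torch.randn(4, B, D)   # learned resampler queries

attn = nn.MultiheadAttention(embed_dim=D, num_heads=4)

# Fixed behaviour: keys carry position, values stay raw content.
out_fixed, _ = attn(queries, x + pos_embed_2d, x)

# Previous behaviour added the positional embedding to the values instead.
out_old, _ = attn(queries, x, x + pos_embed_2d)

print(out_fixed.shape)                     # torch.Size([4, 1, 32])
print(torch.allclose(out_fixed, out_old))  # False: the two conventions differ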