fix minicpmo4.5: fix attn_mask in vit attn && fix resampler pos_emb i… (#34127)

Signed-off-by: tc-mb <caitianchi@modelbest.cn>
Co-authored-by: hezhihui <hezhihui@modelbest.cn>
This commit is contained in:
tc-mb
2026-03-05 01:46:17 +08:00
committed by GitHub
parent d25c1ec3c9
commit bfdb512f11
2 changed files with 69 additions and 7 deletions

View File

@@ -22,6 +22,7 @@ from collections.abc import Iterable
import torch
from torch import nn
from torch.nn import functional as F
from transformers.models.idefics2.configuration_idefics2 import (
Idefics2Config,
Idefics2VisionConfig,
@@ -172,14 +173,41 @@ class Idefics2VisionAttention(nn.Module):
def forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: torch.Tensor | None = None,
) -> torch.Tensor:
    """Self-attention over vision patch embeddings.

    Args:
        hidden_states: patch embeddings; indexing below implies shape
            (batch_size, seq_len, embed_dim) — TODO confirm against caller.
        attention_mask: optional additive attention mask, broadcastable to
            (batch, heads, q_len, kv_len); masked positions are expected to
            already hold large negative values (see the caller that builds
            it from the patch mask). ``None`` means no padding to mask.

    Returns:
        Attention output of the same (batch, seq, embed) layout, after the
        output projection.
    """
    # Fused QKV projection; second tuple element (bias/aux) is discarded.
    qkv, _ = self.qkv_proj(
        hidden_states
    ) # batch_size, q_len, 3 * num_heads_per_partition * head_dim
    # Split the fused projection into equal Q/K/V chunks along the last dim.
    query_states, key_states, value_states = qkv.chunk(3, dim=-1)
    # If attention_mask is provided, prefer Torch SDPA so the mask is
    # correctly applied (aligns with HuggingFace NaViT SigLIP behavior).
    if attention_mask is None:
        # Use unified MMEncoderAttention implementation
        # NOTE(review): self.attn presumably takes (B, L, H*D)-shaped
        # tensors and handles head reshaping internally — confirm.
        out = self.attn(query_states, key_states, value_states)
    else:
        bsz, q_len = query_states.size()[:2]
        kv_len = key_states.size(1)
        # Reshape (B, L, H*D) -> (B, H, L, D) for SDPA's expected layout.
        query = query_states.view(
            bsz, q_len, self.num_heads_per_partition, self.head_dim
        ).transpose(1, 2)
        key = key_states.view(
            bsz, kv_len, self.num_heads_per_partition, self.head_dim
        ).transpose(1, 2)
        value = value_states.view(
            bsz, kv_len, self.num_heads_per_partition, self.head_dim
        ).transpose(1, 2)
        # attn_mask is additive here (float mask), not boolean; scale is
        # passed explicitly rather than relying on SDPA's 1/sqrt(d) default.
        out = F.scaled_dot_product_attention(
            query,
            key,
            value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            scale=self.scale,
        )
        # Back to (B, L, H*D) so both branches feed out_proj identically.
        out = out.transpose(1, 2).reshape(bsz, q_len, -1)
    # Output projection; second tuple element is discarded as above.
    attn_output, _ = self.out_proj(out)
    return attn_output
@@ -245,6 +273,7 @@ class Idefics2EncoderLayer(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Args:
@@ -254,7 +283,7 @@ class Idefics2EncoderLayer(nn.Module):
"""
residual = hidden_states
hidden_states = self.layer_norm1(hidden_states)
hidden_states = self.self_attn(hidden_states)
hidden_states = self.self_attn(hidden_states, attention_mask=attention_mask)
hidden_states += residual
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
@@ -304,6 +333,7 @@ class Idefics2Encoder(nn.Module):
def forward(
self,
inputs_embeds: torch.Tensor,
attention_mask: torch.Tensor | None = None,
) -> torch.Tensor:
r"""
Args:
@@ -316,7 +346,7 @@ class Idefics2Encoder(nn.Module):
"""
hidden_states = inputs_embeds
for encoder_layer in self.layers:
layer_outputs = encoder_layer(hidden_states)
layer_outputs = encoder_layer(hidden_states, attention_mask=attention_mask)
hidden_states = layer_outputs
return hidden_states
@@ -370,15 +400,47 @@ class Idefics2VisionTransformer(nn.Module):
patch_attention_mask: torch.BoolTensor | None = None,
tgt_sizes: torch.IntTensor | None = None,
) -> torch.Tensor:
batch_size = pixel_values.size(0)
if patch_attention_mask is None:
# No mask provided - create default all-ones mask for embeddings
# and skip attention masking (no padding to mask)
patch_attention_mask = torch.ones(
size=(
batch_size,
pixel_values.size(2) // self.config.patch_size,
pixel_values.size(3) // self.config.patch_size,
),
dtype=torch.bool,
device=pixel_values.device,
)
flat_patch_mask = None
else:
flat_patch_mask = patch_attention_mask.view(batch_size, -1)
hidden_states = self.embeddings(
pixel_values=pixel_values,
patch_attention_mask=patch_attention_mask,
tgt_sizes=tgt_sizes,
)
# Align with HuggingFace NaViT SigLIP in MiniCPMV/O:
# - if patch_attention_mask was None, skip attention masking
# - if any padding exists, create an additive 4D mask and pass it
# to attention; else skip mask for performance.
if flat_patch_mask is None or not torch.any(~flat_patch_mask):
attention_mask = None
else:
# Additive mask: masked positions receive a large negative value.
# Shape: (B, 1, 1, L) broadcastable to (B, H, Q, K).
min_val = torch.finfo(hidden_states.dtype).min
attention_mask = (~flat_patch_mask).to(dtype=hidden_states.dtype) * min_val
attention_mask = attention_mask[:, None, None, :]
if self.use_data_parallel:
encoder_outputs = run_dp_sharded_vision_model(hidden_states, self.encoder)
else:
encoder_outputs = self.encoder(hidden_states)
encoder_outputs = self.encoder(hidden_states, attention_mask=attention_mask)
last_hidden_state = self.post_layernorm(encoder_outputs)
return last_hidden_state

View File

@@ -387,8 +387,8 @@ class Resampler4_5(Resampler2_5):
pos_embed_2d, batch_first=True, padding_value=0.0
).permute(1, 0, 2) # BLD => L * B * D
k = x
v = x + pos_embed_2d
k = x + pos_embed_2d
v = x
if pos_embed_temporal:
k += torch.stack(pos_embed_temporal, dim=0)
bs = len(temporal_ids)