fix minicpmo4.5: fix attn_mask in vit attn && fix resampler pos_emb i… (#34127)
Signed-off-by: tc-mb <caitianchi@modelbest.cn>
Co-authored-by: hezhihui <hezhihui@modelbest.cn>
This commit is contained in:
@@ -22,6 +22,7 @@ from collections.abc import Iterable
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
from torch.nn import functional as F
|
||||||
from transformers.models.idefics2.configuration_idefics2 import (
|
from transformers.models.idefics2.configuration_idefics2 import (
|
||||||
Idefics2Config,
|
Idefics2Config,
|
||||||
Idefics2VisionConfig,
|
Idefics2VisionConfig,
|
||||||
@@ -172,14 +173,41 @@ class Idefics2VisionAttention(nn.Module):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
qkv, _ = self.qkv_proj(
|
qkv, _ = self.qkv_proj(
|
||||||
hidden_states
|
hidden_states
|
||||||
) # batch_size, q_len, 3 * num_heads_per_partition * head_dim
|
) # batch_size, q_len, 3 * num_heads_per_partition * head_dim
|
||||||
query_states, key_states, value_states = qkv.chunk(3, dim=-1)
|
query_states, key_states, value_states = qkv.chunk(3, dim=-1)
|
||||||
|
|
||||||
# Use unified MMEncoderAttention implementation
|
# If attention_mask is provided, prefer Torch SDPA so the mask is
|
||||||
out = self.attn(query_states, key_states, value_states)
|
# correctly applied (aligns with HuggingFace NaViT SigLIP behavior).
|
||||||
|
if attention_mask is None:
|
||||||
|
# Use unified MMEncoderAttention implementation
|
||||||
|
out = self.attn(query_states, key_states, value_states)
|
||||||
|
else:
|
||||||
|
bsz, q_len = query_states.size()[:2]
|
||||||
|
kv_len = key_states.size(1)
|
||||||
|
|
||||||
|
query = query_states.view(
|
||||||
|
bsz, q_len, self.num_heads_per_partition, self.head_dim
|
||||||
|
).transpose(1, 2)
|
||||||
|
key = key_states.view(
|
||||||
|
bsz, kv_len, self.num_heads_per_partition, self.head_dim
|
||||||
|
).transpose(1, 2)
|
||||||
|
value = value_states.view(
|
||||||
|
bsz, kv_len, self.num_heads_per_partition, self.head_dim
|
||||||
|
).transpose(1, 2)
|
||||||
|
|
||||||
|
out = F.scaled_dot_product_attention(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
attn_mask=attention_mask,
|
||||||
|
dropout_p=0.0,
|
||||||
|
scale=self.scale,
|
||||||
|
)
|
||||||
|
out = out.transpose(1, 2).reshape(bsz, q_len, -1)
|
||||||
attn_output, _ = self.out_proj(out)
|
attn_output, _ = self.out_proj(out)
|
||||||
return attn_output
|
return attn_output
|
||||||
|
|
||||||
@@ -245,6 +273,7 @@ class Idefics2EncoderLayer(nn.Module):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@@ -254,7 +283,7 @@ class Idefics2EncoderLayer(nn.Module):
|
|||||||
"""
|
"""
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
hidden_states = self.layer_norm1(hidden_states)
|
hidden_states = self.layer_norm1(hidden_states)
|
||||||
hidden_states = self.self_attn(hidden_states)
|
hidden_states = self.self_attn(hidden_states, attention_mask=attention_mask)
|
||||||
hidden_states += residual
|
hidden_states += residual
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
hidden_states = self.layer_norm2(hidden_states)
|
hidden_states = self.layer_norm2(hidden_states)
|
||||||
@@ -304,6 +333,7 @@ class Idefics2Encoder(nn.Module):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
inputs_embeds: torch.Tensor,
|
inputs_embeds: torch.Tensor,
|
||||||
|
attention_mask: torch.Tensor | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
r"""
|
r"""
|
||||||
Args:
|
Args:
|
||||||
@@ -316,7 +346,7 @@ class Idefics2Encoder(nn.Module):
|
|||||||
"""
|
"""
|
||||||
hidden_states = inputs_embeds
|
hidden_states = inputs_embeds
|
||||||
for encoder_layer in self.layers:
|
for encoder_layer in self.layers:
|
||||||
layer_outputs = encoder_layer(hidden_states)
|
layer_outputs = encoder_layer(hidden_states, attention_mask=attention_mask)
|
||||||
hidden_states = layer_outputs
|
hidden_states = layer_outputs
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
@@ -370,15 +400,47 @@ class Idefics2VisionTransformer(nn.Module):
|
|||||||
patch_attention_mask: torch.BoolTensor | None = None,
|
patch_attention_mask: torch.BoolTensor | None = None,
|
||||||
tgt_sizes: torch.IntTensor | None = None,
|
tgt_sizes: torch.IntTensor | None = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
batch_size = pixel_values.size(0)
|
||||||
|
|
||||||
|
if patch_attention_mask is None:
|
||||||
|
# No mask provided - create default all-ones mask for embeddings
|
||||||
|
# and skip attention masking (no padding to mask)
|
||||||
|
patch_attention_mask = torch.ones(
|
||||||
|
size=(
|
||||||
|
batch_size,
|
||||||
|
pixel_values.size(2) // self.config.patch_size,
|
||||||
|
pixel_values.size(3) // self.config.patch_size,
|
||||||
|
),
|
||||||
|
dtype=torch.bool,
|
||||||
|
device=pixel_values.device,
|
||||||
|
)
|
||||||
|
flat_patch_mask = None
|
||||||
|
else:
|
||||||
|
flat_patch_mask = patch_attention_mask.view(batch_size, -1)
|
||||||
|
|
||||||
hidden_states = self.embeddings(
|
hidden_states = self.embeddings(
|
||||||
pixel_values=pixel_values,
|
pixel_values=pixel_values,
|
||||||
patch_attention_mask=patch_attention_mask,
|
patch_attention_mask=patch_attention_mask,
|
||||||
tgt_sizes=tgt_sizes,
|
tgt_sizes=tgt_sizes,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Align with HuggingFace NaViT SigLIP in MiniCPMV/O:
|
||||||
|
# - if patch_attention_mask was None, skip attention masking
|
||||||
|
# - if any padding exists, create an additive 4D mask and pass it
|
||||||
|
# to attention; else skip mask for performance.
|
||||||
|
if flat_patch_mask is None or not torch.any(~flat_patch_mask):
|
||||||
|
attention_mask = None
|
||||||
|
else:
|
||||||
|
# Additive mask: masked positions receive a large negative value.
|
||||||
|
# Shape: (B, 1, 1, L) broadcastable to (B, H, Q, K).
|
||||||
|
min_val = torch.finfo(hidden_states.dtype).min
|
||||||
|
attention_mask = (~flat_patch_mask).to(dtype=hidden_states.dtype) * min_val
|
||||||
|
attention_mask = attention_mask[:, None, None, :]
|
||||||
|
|
||||||
if self.use_data_parallel:
|
if self.use_data_parallel:
|
||||||
encoder_outputs = run_dp_sharded_vision_model(hidden_states, self.encoder)
|
encoder_outputs = run_dp_sharded_vision_model(hidden_states, self.encoder)
|
||||||
else:
|
else:
|
||||||
encoder_outputs = self.encoder(hidden_states)
|
encoder_outputs = self.encoder(hidden_states, attention_mask=attention_mask)
|
||||||
last_hidden_state = self.post_layernorm(encoder_outputs)
|
last_hidden_state = self.post_layernorm(encoder_outputs)
|
||||||
return last_hidden_state
|
return last_hidden_state
|
||||||
|
|
||||||
|
|||||||
@@ -387,8 +387,8 @@ class Resampler4_5(Resampler2_5):
|
|||||||
pos_embed_2d, batch_first=True, padding_value=0.0
|
pos_embed_2d, batch_first=True, padding_value=0.0
|
||||||
).permute(1, 0, 2) # BLD => L * B * D
|
).permute(1, 0, 2) # BLD => L * B * D
|
||||||
|
|
||||||
k = x
|
k = x + pos_embed_2d
|
||||||
v = x + pos_embed_2d
|
v = x
|
||||||
if pos_embed_temporal:
|
if pos_embed_temporal:
|
||||||
k += torch.stack(pos_embed_temporal, dim=0)
|
k += torch.stack(pos_embed_temporal, dim=0)
|
||||||
bs = len(temporal_ids)
|
bs = len(temporal_ids)
|
||||||
|
|||||||
Reference in New Issue
Block a user