[Bugfix] Revert custom attention mask for gemma3-mm (#28995)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
Isotr0py
2025-11-20 13:23:22 +08:00
committed by GitHub
parent fe25772aa9
commit 64192d5624
4 changed files with 1 addition and 172 deletions

@@ -596,7 +596,7 @@ class Gemma3ForConditionalGeneration(
    def get_language_model(self) -> torch.nn.Module:
        return self.language_model

    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []
@@ -644,142 +644,6 @@ class Gemma3ForConditionalGeneration(
        return hidden_states

    def generate_attention_masks(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        mask_dtype: torch.dtype,
    ) -> dict[str, Any]:
        """Generate custom attention masks for Gemma3 multimodal inputs.

        This is called by V1 engine's gpu_model_runner during preprocessing
        to generate attention masks that allow bidirectional attention between
        image tokens while maintaining causal attention for text.
        """
        # NOTE(woosuk): Here, we distinguish the sequences by the position id 0.
        # This is a HACK. Fix this.
        start_indices = (positions == 0).cpu().nonzero()
        num_seqs = len(start_indices)
        seq_lens = []
        for i in range(num_seqs):
            start_idx = start_indices[i]
            end_idx = start_indices[i + 1] if i < num_seqs - 1 else len(input_ids)
            seq_lens.append(end_idx - start_idx)

        global_attn_masks = []
        local_attn_masks = []
        start_idx = 0
        for seq_idx, seq_len in enumerate(seq_lens):
            end_idx = start_idx + seq_len
            input_token_ids = input_ids[start_idx:end_idx]
            # Find image token positions
            img_pos = input_token_ids == self.config.image_token_index
            start_idx = end_idx

            # Create a global causal mask
            global_attn_mask = torch.empty(
                1,
                1,
                seq_len,
                seq_len,
                dtype=mask_dtype,
                device=input_ids.device,
            )
            global_attn_mask.fill_(float("-inf"))
            # Fill the lower triangle with 0 (causal attention)
            global_attn_mask = global_attn_mask.triu(diagonal=1)

            # Enable bidirectional attention between image tokens
            img_mask = torch.zeros_like(global_attn_mask)
            img_mask[:, :, :, img_pos] += 1
            img_mask[:, :, img_pos, :] += 1
            global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
            global_attn_masks.append(global_attn_mask)

            # GGUF compatibility: config might be Gemma3TextConfig directly
            text_config = getattr(self.config, "text_config", self.config)
            sliding_window = text_config.sliding_window
            if sliding_window is not None:
                # Create a local causal mask with sliding window (1024)
                local_attn_mask = torch.ones_like(global_attn_mask)
                local_attn_mask = torch.tril(local_attn_mask, diagonal=-sliding_window)
                local_attn_mask = torch.where(
                    local_attn_mask == 0, global_attn_mask, float("-inf")
                )
                local_attn_masks.append(local_attn_mask)

        return {
            "has_images": True,
            "seq_lens": seq_lens,
            "global_attn_masks": global_attn_masks,
            "local_attn_masks": local_attn_masks,
        }
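For reference, the core of the removed logic, a causal mask opened up so image tokens attend to each other bidirectionally, can be reproduced in isolation. The following is a minimal sketch with toy inputs; IMAGE_TOKEN_ID and the example token ids are placeholders, not values from this diff.

import torch

IMAGE_TOKEN_ID = 9  # placeholder for config.image_token_index
input_token_ids = torch.tensor([1, 9, 9, 4, 5])
seq_len = input_token_ids.numel()

# Causal mask: 0 on and below the diagonal, -inf above it.
mask = torch.full((1, 1, seq_len, seq_len), float("-inf")).triu(diagonal=1)

# Count image-token hits per (query, key) pair; a count of 2 means both the
# query and the key are image tokens, so attention between them is unmasked.
img_pos = input_token_ids == IMAGE_TOKEN_ID
img_mask = torch.zeros_like(mask)
img_mask[:, :, :, img_pos] += 1
img_mask[:, :, img_pos, :] += 1
mask = torch.where(img_mask == 2, torch.zeros_like(mask), mask)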
    def prepare_attn_masks(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        mask_dtype: torch.dtype,
        **kwargs,
    ):
        kwargs["has_images"] = True
        # NOTE(woosuk): Here, we distinguish the sequences by the position id 0.
        # This is a HACK. Fix this.
        start_indices = (positions == 0).cpu().nonzero()
        num_seqs = len(start_indices)
        seq_lens = []
        for i in range(num_seqs):
            start_idx = start_indices[i].item()
            if i < num_seqs - 1:
                end_idx = start_indices[i + 1].item()
            else:
                end_idx = len(input_ids)
            seq_lens.append(end_idx - start_idx)
        kwargs["seq_lens"] = seq_lens

        global_attn_masks = []
        local_attn_masks = []
        start_idx = 0
        for seq_len in seq_lens:
            end_idx = start_idx + seq_len
            input_token_ids = input_ids[start_idx:end_idx]
            start_idx = end_idx

            # Create a global causal mask.
            global_attn_mask = torch.empty(
                1,
                1,
                seq_len,
                seq_len,
                dtype=mask_dtype,
                device=input_ids.device,
            )
            global_attn_mask.fill_(float("-inf"))
            # Fill the lower triangle with 0.
            global_attn_mask = global_attn_mask.triu(diagonal=1)

            # Consider the bidirectional attention between image tokens.
            img_mask = torch.zeros_like(global_attn_mask)
            img_pos = input_token_ids == self.config.image_token_index
            img_mask[:, :, :, img_pos] += 1
            img_mask[:, :, img_pos, :] += 1
            global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
            global_attn_masks.append(global_attn_mask)

            sliding_window = self.config.text_config.sliding_window
            if sliding_window is not None:
                # Create a local causal mask with sliding window (1024).
                local_attn_mask = torch.ones_like(global_attn_mask)
                local_attn_mask = torch.tril(local_attn_mask, diagonal=-sliding_window)
                local_attn_mask = torch.where(
                    local_attn_mask == 0, global_attn_mask, float("-inf")
                )
                local_attn_masks.append(local_attn_mask)

        kwargs["global_attn_masks"] = global_attn_masks
        kwargs["local_attn_masks"] = local_attn_masks
        return kwargs
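The sliding-window branch above builds the local mask from the global one: on top of the causal constraint, keys at least sliding_window positions behind the query are blocked. Below is a standalone sketch with illustrative sizes; seq_len and sliding_window here are arbitrary assumptions, not values taken from the model config.

import torch

seq_len, sliding_window = 6, 2
global_attn_mask = torch.full((1, 1, seq_len, seq_len), float("-inf")).triu(diagonal=1)

# Ones on and below the -sliding_window diagonal mark keys that lie at least
# sliding_window positions behind the query, i.e. outside the local window.
outside_window = torch.ones_like(global_attn_mask).tril(diagonal=-sliding_window)
local_attn_mask = torch.where(
    outside_window == 0,
    global_attn_mask,
    torch.full_like(global_attn_mask, float("-inf")),
)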
    def compute_logits(
        self,
        hidden_states: torch.Tensor,