[Bugfix] Revert custom attention mask for gemma3-mm (#28995)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
@@ -596,7 +596,7 @@ class Gemma3ForConditionalGeneration(
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
 
-    def get_multimodal_embeddings(self, **kwargs: object) -> MultiModalEmbeddings:
+    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
         image_input = self._parse_and_validate_image_input(**kwargs)
         if image_input is None:
             return []
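For orientation, and not part of the diff: embed_multimodal is normally invoked through vLLM's multimodal input pipeline, but its contract is visible above. A minimal usage sketch, where the model instance and the pixel_values keyword are assumptions for illustration only:

# Hypothetical sketch; `model` (a loaded Gemma3ForConditionalGeneration) and
# `pixel_values` (a preprocessed image batch) are assumptions, since real
# calls are routed through vLLM's multimodal runner.
embeddings = model.embed_multimodal(pixel_values=pixel_values)
if not embeddings:
    # Per the code above, an empty list means no image input was found.
    print("no image embeddings")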
@@ -644,142 +644,6 @@ class Gemma3ForConditionalGeneration(
 
         return hidden_states
 
-    def generate_attention_masks(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        mask_dtype: torch.dtype,
-    ) -> dict[str, Any]:
-        """Generate custom attention masks for Gemma3 multimodal inputs.
-
-        This is called by the V1 engine's gpu_model_runner during preprocessing
-        to build attention masks that allow bidirectional attention between
-        image tokens while maintaining causal attention for text.
-        """
-        # NOTE(woosuk): Here, we distinguish the sequences by the position id 0.
-        # This is a HACK. Fix this.
-        start_indices = (positions == 0).cpu().nonzero()
-        num_seqs = len(start_indices)
-        seq_lens = []
-        for i in range(num_seqs):
-            start_idx = start_indices[i].item()
-            end_idx = (
-                start_indices[i + 1].item() if i < num_seqs - 1 else len(input_ids)
-            )
-            seq_lens.append(end_idx - start_idx)
-
-        global_attn_masks = []
-        local_attn_masks = []
-        start_idx = 0
-        for seq_len in seq_lens:
-            end_idx = start_idx + seq_len
-            input_token_ids = input_ids[start_idx:end_idx]
-
-            # Find image token positions.
-            img_pos = input_token_ids == self.config.image_token_index
-
-            start_idx = end_idx
-
-            # Create a global causal mask.
-            global_attn_mask = torch.empty(
-                1,
-                1,
-                seq_len,
-                seq_len,
-                dtype=mask_dtype,
-                device=input_ids.device,
-            )
-            global_attn_mask.fill_(float("-inf"))
-            # Fill the lower triangle with 0 (causal attention).
-            global_attn_mask = global_attn_mask.triu(diagonal=1)
-
-            # Enable bidirectional attention between image tokens.
-            img_mask = torch.zeros_like(global_attn_mask)
-            img_mask[:, :, :, img_pos] += 1
-            img_mask[:, :, img_pos, :] += 1
-            global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
-            global_attn_masks.append(global_attn_mask)
-
-            # GGUF compatibility: the config might be a Gemma3TextConfig directly.
-            text_config = getattr(self.config, "text_config", self.config)
-            sliding_window = text_config.sliding_window
-            if sliding_window is not None:
-                # Create a local causal mask with sliding window (1024).
-                local_attn_mask = torch.ones_like(global_attn_mask)
-                local_attn_mask = torch.tril(local_attn_mask, diagonal=-sliding_window)
-                local_attn_mask = torch.where(
-                    local_attn_mask == 0, global_attn_mask, float("-inf")
-                )
-                local_attn_masks.append(local_attn_mask)
-
-        return {
-            "has_images": True,
-            "seq_lens": seq_lens,
-            "global_attn_masks": global_attn_masks,
-            "local_attn_masks": local_attn_masks,
-        }
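The NOTE comments and the docstring above describe the scheme that both the removed and the restored method implement. A minimal standalone sketch, not part of the diff and with all names hypothetical, showing the position-id-0 splitting and the img_mask == 2 trick on a toy sequence:

import torch

# The "position id 0" hack: a packed batch is split wherever positions
# restarts at 0. Two sequences of lengths 3 and 2 in this toy batch.
positions = torch.tensor([0, 1, 2, 0, 1])
starts = (positions == 0).nonzero().flatten().tolist()
print(starts)  # [0, 3]

# The img_mask == 2 trick on a single 5-token sequence, where IMG is a
# stand-in for config.image_token_index.
IMG = 9
input_ids = torch.tensor([1, IMG, IMG, 2, 3])
seq_len = input_ids.numel()

# Causal base: -inf strictly above the diagonal, 0 on and below it.
mask = torch.empty(1, 1, seq_len, seq_len).fill_(float("-inf")).triu(diagonal=1)

# Each (query, key) entry gets +1 for an image-token column and +1 for an
# image-token row, so it sums to 2 exactly when both tokens are images.
img_pos = input_ids == IMG
img_mask = torch.zeros_like(mask)
img_mask[:, :, :, img_pos] += 1
img_mask[:, :, img_pos, :] += 1
mask = torch.where(img_mask == 2, 0.0, mask)

print(mask[0, 0])
# tensor([[0., -inf, -inf, -inf, -inf],
#         [0., 0., 0., -inf, -inf],   # image token attends ahead to the next image token
#         [0., 0., 0., -inf, -inf],
#         [0., 0., 0., 0., -inf],
#         [0., 0., 0., 0., 0.]])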
+    def prepare_attn_masks(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        mask_dtype: torch.dtype,
+        **kwargs,
+    ):
+        kwargs["has_images"] = True
+        # NOTE(woosuk): Here, we distinguish the sequences by the position id 0.
+        # This is a HACK. Fix this.
+        start_indices = (positions == 0).cpu().nonzero()
+        num_seqs = len(start_indices)
+        seq_lens = []
+        for i in range(num_seqs):
+            start_idx = start_indices[i].item()
+            if i < num_seqs - 1:
+                end_idx = start_indices[i + 1].item()
+            else:
+                end_idx = len(input_ids)
+            seq_lens.append(end_idx - start_idx)
+        kwargs["seq_lens"] = seq_lens
+
+        global_attn_masks = []
+        local_attn_masks = []
+        start_idx = 0
+        for seq_len in seq_lens:
+            end_idx = start_idx + seq_len
+            input_token_ids = input_ids[start_idx:end_idx]
+            start_idx = end_idx
+            # Create a global causal mask.
+            global_attn_mask = torch.empty(
+                1,
+                1,
+                seq_len,
+                seq_len,
+                dtype=mask_dtype,
+                device=input_ids.device,
+            )
+            global_attn_mask.fill_(float("-inf"))
+            # Fill the lower triangle with 0.
+            global_attn_mask = global_attn_mask.triu(diagonal=1)
+
+            # Consider the bidirectional attention between image tokens.
+            img_mask = torch.zeros_like(global_attn_mask)
+            img_pos = input_token_ids == self.config.image_token_index
+            img_mask[:, :, :, img_pos] += 1
+            img_mask[:, :, img_pos, :] += 1
+            global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
+            global_attn_masks.append(global_attn_mask)
+
+            sliding_window = self.config.text_config.sliding_window
+            if sliding_window is not None:
+                # Create a local causal mask with sliding window (1024).
+                local_attn_mask = torch.ones_like(global_attn_mask)
+                local_attn_mask = torch.tril(local_attn_mask, diagonal=-sliding_window)
+                local_attn_mask = torch.where(
+                    local_attn_mask == 0, global_attn_mask, float("-inf")
+                )
+                local_attn_masks.append(local_attn_mask)
+        kwargs["global_attn_masks"] = global_attn_masks
+        kwargs["local_attn_masks"] = local_attn_masks
+        return kwargs
 
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
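Finally, a standalone sketch (not part of the diff) of the sliding-window branch restored above, using a toy window of 2 rather than the config value: tril(..., diagonal=-window) marks the keys that fall outside the window, and torch.where keeps the global mask everywhere else.

import torch

seq_len, window = 5, 2  # toy values; the real window comes from text_config.sliding_window
global_mask = torch.empty(1, 1, seq_len, seq_len).fill_(float("-inf")).triu(diagonal=1)

# Ones exactly where the key is at least `window` positions behind the
# query, i.e. outside the window; those entries become -inf, and the rest
# fall back to the global (causal + image) mask.
outside = torch.tril(torch.ones_like(global_mask), diagonal=-window)
local_mask = torch.where(outside == 0, global_mask, float("-inf"))

print(local_mask[0, 0])
# tensor([[0., -inf, -inf, -inf, -inf],
#         [0., 0., -inf, -inf, -inf],
#         [-inf, 0., 0., -inf, -inf],
#         [-inf, -inf, 0., 0., -inf],
#         [-inf, -inf, -inf, 0., 0.]])
# With window = 2, each query attends only to itself and the previous token.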