[Bugfix] gemma[2,3] interleaved attention when sliding window is disabled (#17180)

Signed-off-by: Chen Zhang <zhangch99@outlook.com>
Chen Zhang
2025-04-26 10:53:51 +08:00
committed by GitHub
parent c53e0730cb
commit 8de2901fea
3 changed files with 15 additions and 11 deletions


@@ -478,7 +478,8 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
         self.config = config
         self.quant_config = quant_config
         self.multimodal_config = multimodal_config
-        self.sliding_window = config.text_config.interleaved_sliding_window
+        self.sliding_window = getattr(config.text_config,
+                                      "interleaved_sliding_window", None)
 
         self.vision_tower = SiglipVisionModel(config.vision_config,
                                               quant_config,
@@ -680,13 +681,14 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
             global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
             global_attn_masks.append(global_attn_mask)
 
-            # Create a local causal mask with sliding window (1024).
-            local_attn_mask = torch.ones_like(global_attn_mask)
-            local_attn_mask = torch.tril(local_attn_mask,
-                                         diagonal=-self.sliding_window)
-            local_attn_mask = torch.where(local_attn_mask == 0,
-                                          global_attn_mask, float("-inf"))
-            local_attn_masks.append(local_attn_mask)
+            if self.sliding_window is not None:
+                # Create a local causal mask with sliding window (1024).
+                local_attn_mask = torch.ones_like(global_attn_mask)
+                local_attn_mask = torch.tril(local_attn_mask,
+                                             diagonal=-self.sliding_window)
+                local_attn_mask = torch.where(local_attn_mask == 0,
+                                              global_attn_mask, float("-inf"))
+                local_attn_masks.append(local_attn_mask)
         kwargs["global_attn_masks"] = global_attn_masks
         kwargs["local_attn_masks"] = local_attn_masks
         return kwargs
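
The fix is twofold: the constructor falls back to None when config.text_config has no interleaved_sliding_window attribute, and the local-mask construction is skipped entirely in that case, so only the global causal masks are populated. The standalone sketch below is not part of the commit; it uses plain PyTorch with a toy 6-token causal mask, and a local sliding_window variable stands in for the model attribute, so the guarded branch can be exercised in both the enabled and disabled cases.

import torch

# Toy stand-ins (assumptions for illustration only): a 6-token additive causal
# mask with 0 where attention is allowed and -inf where it is not.
seq_len = 6
global_attn_mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")),
                              diagonal=1)

# Stands in for getattr(config.text_config, "interleaved_sliding_window", None).
# Set to an int such as 3 (or 1024 in the real model) to enable the local mask,
# or leave as None to emulate a config that ships without the field.
sliding_window = None

local_attn_masks = []
if sliding_window is not None:
    # Same construction as the patched code: positions farther back than the
    # window are masked with -inf; everything else reuses the global mask.
    local_attn_mask = torch.ones_like(global_attn_mask)
    local_attn_mask = torch.tril(local_attn_mask, diagonal=-sliding_window)
    local_attn_mask = torch.where(local_attn_mask == 0,
                                  global_attn_mask, float("-inf"))
    local_attn_masks.append(local_attn_mask)

print(local_attn_masks)  # [] when the sliding window is disabled

With sliding_window = 3 the sketch produces a mask that blocks tokens three or more positions in the past while keeping the causal structure; with None the list stays empty, whereas the direct attribute access in the old constructor would have failed outright for configs lacking interleaved_sliding_window.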