Refactor sliding window configuration to Transformers best practice (#21927)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-08-10 04:50:48 +01:00
Committed by: GitHub
Parent: 2a84fb422f
Commit: c49848396d
16 changed files with 123 additions and 231 deletions

@@ -502,8 +502,6 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
         self.config = config
         self.quant_config = quant_config
         self.multimodal_config = multimodal_config
-        self.sliding_window = getattr(config.text_config,
-                                      "interleaved_sliding_window", None)
         self.vision_tower = SiglipVisionModel(config.vision_config,
                                               quant_config,
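
For context, here is a minimal sketch (not part of this commit; the checkpoint id and the getattr fallbacks are assumptions) of the config access pattern the refactor moves toward: upstream Transformers exposes the window size as the standard sliding_window attribute on the text config, so there is no need to cache a vLLM-specific interleaved_sliding_window value at init time.

# Hypothetical illustration of the "Transformers best practice" lookup.
from transformers import AutoConfig

config = AutoConfig.from_pretrained("google/gemma-3-4b-it")  # assumed id
# Multimodal configs nest the language model settings under text_config.
text_config = getattr(config, "text_config", config)

# Legacy, vLLM-specific attribute removed by this commit:
legacy = getattr(text_config, "interleaved_sliding_window", None)

# Standard Transformers attribute, read at the point of use instead:
sliding_window = getattr(text_config, "sliding_window", None)
print(legacy, sliding_window)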
@@ -690,11 +688,11 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
             global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
             global_attn_masks.append(global_attn_mask)
-            if self.sliding_window is not None:
+            if (sliding_window := self.config.sliding_window) is not None:
                 # Create a local causal mask with sliding window (1024).
                 local_attn_mask = torch.ones_like(global_attn_mask)
                 local_attn_mask = torch.tril(local_attn_mask,
-                                             diagonal=-self.sliding_window)
+                                             diagonal=-sliding_window)
                 local_attn_mask = torch.where(local_attn_mask == 0,
                                               global_attn_mask, float("-inf"))
                 local_attn_masks.append(local_attn_mask)
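
To see what the unchanged mask logic in this hunk does, here is a self-contained toy demo (not vLLM code; the sizes and the additive 0/-inf mask convention are assumptions): torch.tril with diagonal=-sliding_window marks keys strictly older than the window, and torch.where then sends those positions to -inf while everything inside the window falls back to the global causal mask.

import torch

seq_len, sliding_window = 6, 2  # toy values for illustration

# Additive causal mask: 0 = attend, -inf = masked (future positions).
global_attn_mask = torch.triu(
    torch.full((seq_len, seq_len), float("-inf")), diagonal=1)

# Ones at keys strictly older than the window (j <= i - sliding_window)...
local_attn_mask = torch.ones_like(global_attn_mask)
local_attn_mask = torch.tril(local_attn_mask, diagonal=-sliding_window)

# ...become -inf; positions inside the window keep the global mask value.
local_attn_mask = torch.where(local_attn_mask == 0, global_attn_mask,
                              float("-inf"))

# Row i now attends only to keys j with i - sliding_window < j <= i.
print(local_attn_mask)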