[Bugfix] gemma[2,3] interleaved attention when sliding window is disabled (#17180)
Signed-off-by: Chen Zhang <zhangch99@outlook.com>
This commit is contained in:
@@ -478,7 +478,8 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
self.multimodal_config = multimodal_config
|
||||
self.sliding_window = config.text_config.interleaved_sliding_window
|
||||
self.sliding_window = getattr(config.text_config,
|
||||
"interleaved_sliding_window", None)
|
||||
|
||||
self.vision_tower = SiglipVisionModel(config.vision_config,
|
||||
quant_config,
|
||||
@@ -680,13 +681,14 @@ class Gemma3ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP,
|
||||
global_attn_mask = torch.where(img_mask == 2, 0, global_attn_mask)
|
||||
global_attn_masks.append(global_attn_mask)
|
||||
|
||||
# Create a local causal mask with sliding window (1024).
|
||||
local_attn_mask = torch.ones_like(global_attn_mask)
|
||||
local_attn_mask = torch.tril(local_attn_mask,
|
||||
diagonal=-self.sliding_window)
|
||||
local_attn_mask = torch.where(local_attn_mask == 0,
|
||||
global_attn_mask, float("-inf"))
|
||||
local_attn_masks.append(local_attn_mask)
|
||||
if self.sliding_window is not None:
|
||||
# Create a local causal mask with sliding window (1024).
|
||||
local_attn_mask = torch.ones_like(global_attn_mask)
|
||||
local_attn_mask = torch.tril(local_attn_mask,
|
||||
diagonal=-self.sliding_window)
|
||||
local_attn_mask = torch.where(local_attn_mask == 0,
|
||||
global_attn_mask, float("-inf"))
|
||||
local_attn_masks.append(local_attn_mask)
|
||||
kwargs["global_attn_masks"] = global_attn_masks
|
||||
kwargs["local_attn_masks"] = local_attn_masks
|
||||
return kwargs
|
||||
|
||||
Reference in New Issue
Block a user