[BugFix] Fix interleaved sliding window not set for Gemma3n (#21863)
Signed-off-by: Yong Hoon Shin <yhshin@meta.com>
This commit is contained in:
@@ -723,11 +723,16 @@ class ModelConfig:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Workaround for Gemma 2 which uses interleaved sliding window
|
# Workaround for Gemma 2 which uses interleaved sliding window
|
||||||
# attention, but it's not specified in its config. TODO: remove this
|
# attention, but it's not specified in its config.
|
||||||
# when Gemma 2 is fixed in Transformers.
|
# TODO: remove this when Gemma 2 config updated in HuggingFace.
|
||||||
if self.hf_text_config.model_type == "gemma2":
|
if self.hf_text_config.model_type == "gemma2":
|
||||||
self.hf_text_config.sliding_window_pattern = 2
|
self.hf_text_config.sliding_window_pattern = 2
|
||||||
|
|
||||||
|
# TODO: remove this when Gemma 3n config updated in HuggingFace.
|
||||||
|
if self.hf_text_config.model_type == "gemma3n_text":
|
||||||
|
# 4 sliding window attention followed by 1 full attention
|
||||||
|
self.hf_text_config.sliding_window_pattern = "LLLLG"
|
||||||
|
|
||||||
sliding_window = getattr(self.hf_text_config, "sliding_window", None)
|
sliding_window = getattr(self.hf_text_config, "sliding_window", None)
|
||||||
sliding_window_pattern = getattr(self.hf_text_config,
|
sliding_window_pattern = getattr(self.hf_text_config,
|
||||||
"sliding_window_pattern", None)
|
"sliding_window_pattern", None)
|
||||||
|
|||||||
@@ -297,8 +297,13 @@ class Gemma3nAttention(nn.Module):
|
|||||||
has_weight=False)
|
has_weight=False)
|
||||||
|
|
||||||
layer_idx = extract_layer_index(prefix)
|
layer_idx = extract_layer_index(prefix)
|
||||||
if config.layer_types[layer_idx] == "sliding_attention":
|
|
||||||
self.sliding_window = config.sliding_window
|
is_sliding_window = (
|
||||||
|
getattr(config, "interleaved_sliding_window", None) is not None
|
||||||
|
and config.layer_types[layer_idx] == "sliding_attention")
|
||||||
|
|
||||||
|
if is_sliding_window:
|
||||||
|
self.sliding_window = config.interleaved_sliding_window
|
||||||
rope_theta = config.rope_local_base_freq
|
rope_theta = config.rope_local_base_freq
|
||||||
rope_scaling = {"rope_type": "default"}
|
rope_scaling = {"rope_type": "default"}
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user