Refactor sliding window configuration to Transformers best practice (#21927)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor
2025-08-10 04:50:48 +01:00
committed by GitHub
parent 2a84fb422f
commit c49848396d
16 changed files with 123 additions and 231 deletions


@@ -280,6 +280,17 @@ def is_encoder_decoder(config: PretrainedConfig) -> bool:
    return getattr(config, "is_encoder_decoder", False)


def is_interleaved(config: PretrainedConfig) -> bool:
    """
    Detect if the model with this config is used with interleaved attention.
    """
    text_config = config.get_text_config()
    if layer_types := getattr(text_config, "layer_types", None):
        interleaved_types = {"full_attention", "sliding_attention"}
        return interleaved_types.issubset(layer_types)
    return False


def _maybe_remap_hf_config_attrs(config: PretrainedConfig) -> PretrainedConfig:
    """Remap config attributes to match the expected names."""
    for old_attr, new_attr in _CONFIG_ATTRS_MAPPING.items():
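
For context, a minimal sketch of what the new helper reports, with is_interleaved defined as in the hunk above. The configs, layer_types values, and window size here are hypothetical, not taken from any real checkpoint:

from transformers import PretrainedConfig

# Hypothetical 4-layer model alternating sliding and full attention,
# in the layer_types style Transformers now recommends.
interleaved = PretrainedConfig(
    layer_types=["sliding_attention", "full_attention",
                 "sliding_attention", "full_attention"],
    sliding_window=4096,
)
# Hypothetical model where every layer uses full attention.
uniform = PretrainedConfig(layer_types=["full_attention"] * 4)

print(is_interleaved(interleaved))  # True: both attention kinds are present
print(is_interleaved(uniform))      # False: only one attention kind appears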
@@ -423,6 +434,23 @@ def get_config(
            raise e

        config = _maybe_remap_hf_config_attrs(config)

        # Phi4Flash misuses sliding_window as a list[int]. Convert it to an
        # int and add the layer_types list[str] to make it HF compatible.
        if config.model_type == "phi4flash":
            # TODO: Remove after the following PR is merged:
            # https://huggingface.co/microsoft/Phi-4-mini-flash-reasoning/discussions/6
            if not hasattr(config, "layer_types"):
                config.layer_types = [
                    "sliding_attention" if i < config.num_hidden_layers // 2
                    and i % 2 == 1 else "full_attention"
                    for i in range(config.num_hidden_layers)
                ]
            # TODO: Remove after the following PR is merged:
            # https://huggingface.co/microsoft/Phi-4-mini-flash-reasoning/discussions/7
            if isinstance(config.sliding_window, list):
                config.sliding_window = next(
                    filter(None, config.sliding_window), None)

    elif config_format == ConfigFormat.MISTRAL:
        # This function loads a params.json config which
        # should be used when loading models in mistral format.
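
To make the Phi4Flash remap above concrete, here is the same logic evaluated standalone for a hypothetical 8-layer model (the real checkpoint has more layers, and the window sizes below are invented):

num_hidden_layers = 8  # hypothetical; the real model is larger

# Layers in the first half with an odd index get sliding attention,
# everything else gets full attention.
layer_types = [
    "sliding_attention" if i < num_hidden_layers // 2 and i % 2 == 1
    else "full_attention"
    for i in range(num_hidden_layers)
]
print(layer_types)
# ['full_attention', 'sliding_attention', 'full_attention', 'sliding_attention',
#  'full_attention', 'full_attention', 'full_attention', 'full_attention']

# The list-valued sliding_window collapses to its first truthy entry.
sliding_window = [None, 512, None, 512, None, None, None, None]  # invented values
print(next(filter(None, sliding_window), None))  # 512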
@@ -434,6 +462,18 @@ def get_config(
            config_dict["max_position_embeddings"] = max_position_embeddings
        config = adapt_config_dict(config_dict)

        # Mistral configs may define sliding_window as list[int]. Convert it
        # to int and add the layer_types list[str] to make it HF compatible.
        if ((sliding_window := getattr(config, "sliding_window", None))
                and isinstance(sliding_window, list)):
            pattern_repeats = config.num_hidden_layers // len(sliding_window)
            layer_types = sliding_window * pattern_repeats
            config.layer_types = [
                "full_attention" if layer_type is None else "sliding_attention"
                for layer_type in layer_types
            ]
            config.sliding_window = next(filter(None, sliding_window), None)
    else:
        supported_formats = [
            fmt.value for fmt in ConfigFormat if fmt != ConfigFormat.AUTO
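
For reference, a standalone sketch of the Mistral-format conversion above, run on a hypothetical 6-layer model whose params.json declares a two-entry sliding_window pattern (all values invented):

num_hidden_layers = 6          # hypothetical
sliding_window = [4096, None]  # hypothetical per-layer pattern from params.json

# Tile the pattern across all layers, then map each entry to an HF layer type.
pattern_repeats = num_hidden_layers // len(sliding_window)
layer_types = [
    "full_attention" if window is None else "sliding_attention"
    for window in sliding_window * pattern_repeats
]
print(layer_types)
# ['sliding_attention', 'full_attention', 'sliding_attention', 'full_attention',
#  'sliding_attention', 'full_attention']

# sliding_window itself becomes the first non-None entry, as a plain int.
print(next(filter(None, sliding_window), None))  # 4096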