Refactor sliding window configuration to Transformers best practice (#21927)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
@@ -280,6 +280,17 @@ def is_encoder_decoder(config: PretrainedConfig) -> bool:
     return getattr(config, "is_encoder_decoder", False)
 
 
+def is_interleaved(config: PretrainedConfig) -> bool:
+    """
+    Detect if the model with this config is used with interleaved attention.
+    """
+    text_config = config.get_text_config()
+    if layer_types := getattr(text_config, "layer_types", None):
+        interleaved_types = {"full_attention", "sliding_attention"}
+        return interleaved_types.issubset(layer_types)
+    return False
+
+
 def _maybe_remap_hf_config_attrs(config: PretrainedConfig) -> PretrainedConfig:
     """Remap config attributes to match the expected names."""
     for old_attr, new_attr in _CONFIG_ATTRS_MAPPING.items():
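For intuition, the new helper returns True only when a model's text config mixes both attention kinds. A minimal sketch of its behavior (not part of the diff), assuming a recent transformers where PretrainedConfig accepts arbitrary kwargs and get_text_config() returns the config itself for non-composite models, and assuming the helper is importable from vllm.transformers_utils.config, the module this hunk patches:

from transformers import PretrainedConfig

from vllm.transformers_utils.config import is_interleaved

# Both attention kinds present -> interleaved.
mixed = PretrainedConfig(
    layer_types=["sliding_attention", "full_attention"] * 2)
# Only one kind present -> not interleaved.
uniform = PretrainedConfig(layer_types=["full_attention"] * 4)

assert is_interleaved(mixed)
assert not is_interleaved(uniform)
assert not is_interleaved(PretrainedConfig())  # no layer_types attribute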
@@ -423,6 +434,23 @@ def get_config(
             raise e
         config = _maybe_remap_hf_config_attrs(config)
 
+        # Phi4Flash misuses this config as list[int]. Convert it to int and add
+        # the layer_types list[str] to make it HF compatible
+        if config.model_type == "phi4flash":
+            # TODO: Remove after the following PR is merged:
+            # https://huggingface.co/microsoft/Phi-4-mini-flash-reasoning/discussions/6
+            if not hasattr(config, "layer_types"):
+                config.layer_types = [
+                    "sliding_attention" if i < config.num_hidden_layers // 2
+                    and i % 2 == 1 else "full_attention"
+                    for i in range(config.num_hidden_layers)
+                ]
+            # TODO: Remove after the following PR is merged:
+            # https://huggingface.co/microsoft/Phi-4-mini-flash-reasoning/discussions/7
+            if isinstance(config.sliding_window, list):
+                config.sliding_window = next(
+                    filter(None, config.sliding_window), None)
+
     elif config_format == ConfigFormat.MISTRAL:
         # This function loads a params.json config which
         # should be used when loading models in mistral format
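To see what the synthesized pattern looks like, here is a standalone restatement of the Phi4Flash rule (not the PR code itself); num_hidden_layers=32 is an illustrative assumption, not read from the real config. Only odd-indexed layers in the first half of the stack get sliding attention, and a list-valued sliding_window collapses to its first non-falsy entry:

num_hidden_layers = 32  # illustrative; vLLM reads this from the config

layer_types = [
    "sliding_attention" if i < num_hidden_layers // 2 and i % 2 == 1
    else "full_attention"
    for i in range(num_hidden_layers)
]

# Sliding attention lands on odd layers in the first half only.
assert [i for i, t in enumerate(layer_types)
        if t == "sliding_attention"] == [1, 3, 5, 7, 9, 11, 13, 15]

# next(filter(None, ...)) picks the first non-falsy window, or None.
assert next(filter(None, [None, 512, None]), None) == 512
assert next(filter(None, [None, None]), None) is None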
@@ -434,6 +462,18 @@ def get_config(
             config_dict["max_position_embeddings"] = max_position_embeddings
 
         config = adapt_config_dict(config_dict)
+
+        # Mistral configs may define sliding_window as list[int]. Convert it
+        # to int and add the layer_types list[str] to make it HF compatible
+        if ((sliding_window := getattr(config, "sliding_window", None))
+                and isinstance(sliding_window, list)):
+            pattern_repeats = config.num_hidden_layers // len(sliding_window)
+            layer_types = sliding_window * pattern_repeats
+            config.layer_types = [
+                "full_attention" if layer_type is None else "sliding_attention"
+                for layer_type in layer_types
+            ]
+            config.sliding_window = next(filter(None, sliding_window), None)
     else:
         supported_formats = [
             fmt.value for fmt in ConfigFormat if fmt != ConfigFormat.AUTO
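A worked example of the Mistral conversion above, with illustrative values rather than a real params.json: a two-entry sliding_window list tiles across the layer stack, None entries become full attention, and the scalar window is the first non-None entry:

num_hidden_layers = 8            # illustrative
sliding_window = [None, 4096]    # alternating global / sliding pattern

pattern_repeats = num_hidden_layers // len(sliding_window)   # 4
layer_types = [
    "full_attention" if w is None else "sliding_attention"
    for w in sliding_window * pattern_repeats
]

assert layer_types == ["full_attention", "sliding_attention"] * 4
assert next(filter(None, sliding_window), None) == 4096      # scalar window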