Update rope_scaling to rope_parameters in preparation for Transformers v5 (#28542)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-11-19 18:06:36 +01:00
Committed by: GitHub
Parent: d44e9df7d4
Commit: a8b70304d6
104 changed files with 542 additions and 910 deletions


@@ -7,8 +7,9 @@ import time
 from collections.abc import Callable
 from dataclasses import asdict
 from functools import cache, partial
+from importlib.metadata import version
 from pathlib import Path
-from typing import Any, Literal, TypeVar
+from typing import Any, Literal, TypeAlias, TypeVar

 import huggingface_hub
 from huggingface_hub import (
@@ -24,7 +25,9 @@ from huggingface_hub.utils import (
     RepositoryNotFoundError,
     RevisionNotFoundError,
 )
+from packaging.version import Version
 from transformers import DeepseekV3Config, GenerationConfig, PretrainedConfig
+from transformers.configuration_utils import ALLOWED_LAYER_TYPES
 from transformers.models.auto.image_processing_auto import get_image_processor_config
 from transformers.models.auto.modeling_auto import (
     MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
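
Note (not part of the diff): together with the `importlib.metadata.version` import added above, these two new imports exist so the config patching can branch on the installed Transformers version. A minimal sketch of that gate, assuming `transformers` is installed:

```python
# Minimal sketch of the version gate the new imports enable; assumes the
# `transformers` package is installed. The "5.0.0.dev0" threshold mirrors the
# check used in patch_rope_parameters below.
from importlib.metadata import version

from packaging.version import Version

if Version(version("transformers")) >= Version("5.0.0.dev0"):
    print("Transformers v5: configs already carry `rope_parameters`")
else:
    print("Transformers v4: `rope_theta`/`rope_scaling` must be converted")
```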
@@ -390,21 +393,61 @@ def file_or_path_exists(
     )


-def patch_rope_scaling(config: PretrainedConfig) -> None:
+def set_default_rope_theta(config: PretrainedConfig, default_theta: float) -> None:
+    """Some models may have no rope_theta in their config but still use RoPE.
+    This function sets a default rope_theta if it's missing."""
+    if getattr(config, "rope_parameters", None) is None:
+        config.rope_parameters = {"rope_type": "default"}
+    if "rope_theta" not in config.rope_parameters:
+        config.rope_parameters["rope_theta"] = default_theta
+
+
+def patch_rope_parameters(config: PretrainedConfig) -> None:
     """Provide backwards compatibility for RoPE."""
-    text_config = getattr(config, "text_config", None)
-    if text_config is not None:
-        patch_rope_scaling(text_config)
+    # Retrieve rope_parameters differently based on Transformers version
+    if Version(version("transformers")) >= Version("5.0.0.dev0"):
+        from transformers.modeling_rope_utils import RopeParameters

-    rope_scaling = getattr(config, "rope_scaling", None)
-    if rope_scaling is not None:
-        patch_rope_scaling_dict(rope_scaling)
+        rope_parameters: RopeParameters | dict[str, RopeParameters] | None = getattr(
+            config, "rope_parameters", None
+        )
+    elif hasattr(config, "rope_parameters"):
+        # We are in Transformers v4 and rope_parameters
+        # has already been patched for this config
+        return
+    else:
+        # Convert Transformers v4 rope_theta and rope_scaling into rope_parameters
+        rope_theta: float | None = getattr(config, "rope_theta", None)
+        rope_scaling: dict | None = getattr(config, "rope_scaling", None)
+        rope_parameters = rope_scaling
+        # Move rope_theta into rope_parameters
+        if rope_theta is not None:
+            rope_parameters = rope_parameters or {"rope_type": "default"}
+            rope_parameters["rope_theta"] = rope_theta
+        # Add original_max_position_embeddings if present
+        if rope_parameters and (
+            ompe := getattr(config, "original_max_position_embeddings", None)
+        ):
+            rope_parameters["original_max_position_embeddings"] = ompe
+        # Write back to config
+        config.rope_parameters = rope_parameters
+
+    # No RoPE parameters to patch
+    if rope_parameters is None:
+        return
+
+    # Handle nested rope_parameters in interleaved sliding attention models
+    if set(rope_parameters.keys()).issubset(ALLOWED_LAYER_TYPES):
+        for rope_parameters_layer_type in rope_parameters.values():
+            patch_rope_parameters_dict(rope_parameters_layer_type)
+    else:
+        patch_rope_parameters_dict(rope_parameters)


-def patch_rope_scaling_dict(rope_scaling: dict[str, Any]) -> None:
-    if "rope_type" in rope_scaling and "type" in rope_scaling:
-        rope_type = rope_scaling["rope_type"]
-        rope_type_legacy = rope_scaling["type"]
+def patch_rope_parameters_dict(rope_parameters: dict[str, Any]) -> None:
+    if "rope_type" in rope_parameters and "type" in rope_parameters:
+        rope_type = rope_parameters["rope_type"]
+        rope_type_legacy = rope_parameters["type"]
         if rope_type != rope_type_legacy:
             raise ValueError(
                 f"Found conflicts between 'rope_type={rope_type}' (modern "
@@ -412,28 +455,28 @@ def patch_rope_scaling_dict(rope_scaling: dict[str, Any]) -> None:
                 "You should only specify one of them."
             )

-    if "rope_type" not in rope_scaling and "type" in rope_scaling:
-        rope_scaling["rope_type"] = rope_scaling["type"]
+    if "rope_type" not in rope_parameters and "type" in rope_parameters:
+        rope_parameters["rope_type"] = rope_parameters["type"]
         logger.info("Replacing legacy 'type' key with 'rope_type'")

-    if "rope_type" not in rope_scaling:
-        raise ValueError("rope_scaling should have a 'rope_type' key")
+    if "rope_type" not in rope_parameters:
+        raise ValueError("rope_parameters should have a 'rope_type' key")

-    if rope_scaling["rope_type"] == "su":
-        rope_scaling["rope_type"] = "longrope"
+    if rope_parameters["rope_type"] == "su":
+        rope_parameters["rope_type"] = "longrope"
         logger.warning("Replacing legacy rope_type 'su' with 'longrope'")
-    elif rope_scaling["rope_type"] == "mrope":
-        assert "mrope_section" in rope_scaling
-        rope_scaling["rope_type"] = "default"
+    elif rope_parameters["rope_type"] == "mrope":
+        assert "mrope_section" in rope_parameters
+        rope_parameters["rope_type"] = "default"
         logger.warning("Replacing legacy rope_type 'mrope' with 'default'")


 def _uses_mrope(config: PretrainedConfig) -> bool:
-    rope_scaling = getattr(config, "rope_scaling", None)
-    if rope_scaling is None:
+    rope_parameters = getattr(config, "rope_parameters", None)
+    if rope_parameters is None:
         return False

-    return "mrope_section" in rope_scaling
+    return "mrope_section" in rope_parameters


 def uses_mrope(config: PretrainedConfig) -> bool:
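
A small before/after for the legacy-key normalisation performed by `patch_rope_parameters_dict` (the dict contents are invented, not taken from a real model config):

```python
# Invented example of the normalisation above: the legacy "type" key is copied
# to "rope_type", and the legacy "su" name is rewritten to "longrope".
legacy = {"type": "su", "factor": 2.0}

if "rope_type" not in legacy and "type" in legacy:
    legacy["rope_type"] = legacy["type"]
if legacy["rope_type"] == "su":
    legacy["rope_type"] = "longrope"

assert legacy == {"type": "su", "factor": 2.0, "rope_type": "longrope"}
```

Likewise, a legacy `"mrope"` type is rewritten to `"default"`, with `mrope_section` left in place so `_uses_mrope` can still detect it.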
@@ -690,7 +733,14 @@ def get_config(
         logger.debug("Overriding HF config with %s", hf_overrides_fn)
         config = hf_overrides_fn(config)

-    patch_rope_scaling(config)
+    # Exhaustively patch RoPE parameters everywhere they might be
+    patch_rope_parameters(config)
+    patch_rope_parameters(config.get_text_config())
+    SubConfigs: TypeAlias = dict[str, PretrainedConfig]
+    sub_configs: SubConfigs | None = getattr(config, "sub_configs", None)
+    if sub_configs:
+        for sub_config in sub_configs:
+            patch_rope_parameters(getattr(config, sub_config))

     if trust_remote_code:
         maybe_register_config_serialize_by_value()
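
For illustration only (not part of the commit): the traversal added to `get_config` patches the top-level config, its text config, and every declared sub-config. A rough sketch with a stand-in config object, where the sub-config names are hypothetical and `patch()` is a stub for `patch_rope_parameters`:

```python
# Stand-in sketch of the traversal added to get_config; "text_config" and
# "vision_config" are hypothetical sub-config names.
from types import SimpleNamespace

def patch(cfg):
    print("patched", cfg.name)

text = SimpleNamespace(name="text_config")
vision = SimpleNamespace(name="vision_config")
config = SimpleNamespace(
    name="top_level",
    text_config=text,
    vision_config=vision,
    get_text_config=lambda: text,
    # sub_configs maps sub-config attribute names to their config types;
    # only the keys matter for this traversal.
    sub_configs={"text_config": None, "vision_config": None},
)

patch(config)                      # the top-level config itself
patch(config.get_text_config())    # its text config
for name in config.sub_configs:    # every declared sub-config, by attribute name
    patch(getattr(config, name))
```

Patching the same sub-config more than once is harmless: on Transformers v4, `patch_rope_parameters` returns early once `rope_parameters` has been written, and the per-dict rewrites in `patch_rope_parameters_dict` are idempotent.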