Update rope_scaling to rope_parameters in preparation for Transformers v5 (#28542)
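A rough sketch of the config-level change this migration tracks (hypothetical
values, not taken from any real checkpoint): under Transformers v4 a model's
RoPE settings live in the separate rope_theta and rope_scaling attributes,
while v5 folds them into a single rope_parameters dict. The new
patch_rope_parameters() reproduces that folding for v4 configs, roughly:

    # Illustrative only; the attribute values below are made up.
    from transformers import PretrainedConfig

    config = PretrainedConfig()
    config.rope_theta = 10000.0                                  # v4 attribute
    config.rope_scaling = {"rope_type": "yarn", "factor": 4.0}   # v4 attribute

    # After patch_rope_parameters(config) on a v4 install, the two are merged:
    # config.rope_parameters == {
    #     "rope_type": "yarn",
    #     "factor": 4.0,
    #     "rope_theta": 10000.0,
    # }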
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
@@ -7,8 +7,9 @@ import time
 from collections.abc import Callable
 from dataclasses import asdict
 from functools import cache, partial
+from importlib.metadata import version
 from pathlib import Path
-from typing import Any, Literal, TypeVar
+from typing import Any, Literal, TypeAlias, TypeVar
 
 import huggingface_hub
 from huggingface_hub import (
@@ -24,7 +25,9 @@ from huggingface_hub.utils import (
     RepositoryNotFoundError,
     RevisionNotFoundError,
 )
+from packaging.version import Version
 from transformers import DeepseekV3Config, GenerationConfig, PretrainedConfig
+from transformers.configuration_utils import ALLOWED_LAYER_TYPES
 from transformers.models.auto.image_processing_auto import get_image_processor_config
 from transformers.models.auto.modeling_auto import (
     MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
@@ -390,21 +393,61 @@ def file_or_path_exists(
     )
 
 
-def patch_rope_scaling(config: PretrainedConfig) -> None:
+def set_default_rope_theta(config: PretrainedConfig, default_theta: float) -> None:
+    """Some models may have no rope_theta in their config but still use RoPE.
+    This function sets a default rope_theta if it's missing."""
+    if getattr(config, "rope_parameters", None) is None:
+        config.rope_parameters = {"rope_type": "default"}
+    if "rope_theta" not in config.rope_parameters:
+        config.rope_parameters["rope_theta"] = default_theta
+
+
+def patch_rope_parameters(config: PretrainedConfig) -> None:
     """Provide backwards compatibility for RoPE."""
-    text_config = getattr(config, "text_config", None)
-    if text_config is not None:
-        patch_rope_scaling(text_config)
+    # Retrieve rope_parameters differently based on Transformers version
+    if Version(version("transformers")) >= Version("5.0.0.dev0"):
+        from transformers.modeling_rope_utils import RopeParameters
 
-    rope_scaling = getattr(config, "rope_scaling", None)
-    if rope_scaling is not None:
-        patch_rope_scaling_dict(rope_scaling)
+        rope_parameters: RopeParameters | dict[str, RopeParameters] | None = getattr(
+            config, "rope_parameters", None
+        )
+    elif hasattr(config, "rope_parameters"):
+        # We are in Transformers v4 and rope_parameters
+        # has already been patched for this config
+        return
+    else:
+        # Convert Transformers v4 rope_theta and rope_scaling into rope_parameters
+        rope_theta: float | None = getattr(config, "rope_theta", None)
+        rope_scaling: dict | None = getattr(config, "rope_scaling", None)
+        rope_parameters = rope_scaling
+        # Move rope_theta into rope_parameters
+        if rope_theta is not None:
+            rope_parameters = rope_parameters or {"rope_type": "default"}
+            rope_parameters["rope_theta"] = rope_theta
+        # Add original_max_position_embeddings if present
+        if rope_parameters and (
+            ompe := getattr(config, "original_max_position_embeddings", None)
+        ):
+            rope_parameters["original_max_position_embeddings"] = ompe
+        # Write back to config
+        config.rope_parameters = rope_parameters
+
+    # No RoPE parameters to patch
+    if rope_parameters is None:
+        return
+
+    # Handle nested rope_parameters in interleaved sliding attention models
+    if set(rope_parameters.keys()).issubset(ALLOWED_LAYER_TYPES):
+        for rope_parameters_layer_type in rope_parameters.values():
+            patch_rope_parameters_dict(rope_parameters_layer_type)
+    else:
+        patch_rope_parameters_dict(rope_parameters)
 
 
-def patch_rope_scaling_dict(rope_scaling: dict[str, Any]) -> None:
-    if "rope_type" in rope_scaling and "type" in rope_scaling:
-        rope_type = rope_scaling["rope_type"]
-        rope_type_legacy = rope_scaling["type"]
+def patch_rope_parameters_dict(rope_parameters: dict[str, Any]) -> None:
+    if "rope_type" in rope_parameters and "type" in rope_parameters:
+        rope_type = rope_parameters["rope_type"]
+        rope_type_legacy = rope_parameters["type"]
         if rope_type != rope_type_legacy:
             raise ValueError(
                 f"Found conflicts between 'rope_type={rope_type}' (modern "
@@ -412,28 +455,28 @@ def patch_rope_scaling_dict(rope_scaling: dict[str, Any]) -> None:
                 "You should only specify one of them."
             )
 
-    if "rope_type" not in rope_scaling and "type" in rope_scaling:
-        rope_scaling["rope_type"] = rope_scaling["type"]
+    if "rope_type" not in rope_parameters and "type" in rope_parameters:
+        rope_parameters["rope_type"] = rope_parameters["type"]
         logger.info("Replacing legacy 'type' key with 'rope_type'")
 
-    if "rope_type" not in rope_scaling:
-        raise ValueError("rope_scaling should have a 'rope_type' key")
+    if "rope_type" not in rope_parameters:
+        raise ValueError("rope_parameters should have a 'rope_type' key")
 
-    if rope_scaling["rope_type"] == "su":
-        rope_scaling["rope_type"] = "longrope"
+    if rope_parameters["rope_type"] == "su":
+        rope_parameters["rope_type"] = "longrope"
         logger.warning("Replacing legacy rope_type 'su' with 'longrope'")
-    elif rope_scaling["rope_type"] == "mrope":
-        assert "mrope_section" in rope_scaling
-        rope_scaling["rope_type"] = "default"
+    elif rope_parameters["rope_type"] == "mrope":
+        assert "mrope_section" in rope_parameters
+        rope_parameters["rope_type"] = "default"
         logger.warning("Replacing legacy rope_type 'mrope' with 'default'")
 
 
 def _uses_mrope(config: PretrainedConfig) -> bool:
-    rope_scaling = getattr(config, "rope_scaling", None)
-    if rope_scaling is None:
+    rope_parameters = getattr(config, "rope_parameters", None)
+    if rope_parameters is None:
         return False
 
-    return "mrope_section" in rope_scaling
+    return "mrope_section" in rope_parameters
 
 
 def uses_mrope(config: PretrainedConfig) -> bool:
@@ -690,7 +733,14 @@ def get_config(
         logger.debug("Overriding HF config with %s", hf_overrides_fn)
         config = hf_overrides_fn(config)
 
-    patch_rope_scaling(config)
+    # Exhaustively patch RoPE parameters everywhere they might be
+    patch_rope_parameters(config)
+    patch_rope_parameters(config.get_text_config())
+    SubConfigs: TypeAlias = dict[str, PretrainedConfig]
+    sub_configs: SubConfigs | None = getattr(config, "sub_configs", None)
+    if sub_configs:
+        for sub_config in sub_configs:
+            patch_rope_parameters(getattr(config, sub_config))
 
     if trust_remote_code:
         maybe_register_config_serialize_by_value()
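A closing sketch of the nested case the new patch_rope_parameters() also
handles: models with interleaved sliding attention may key rope_parameters by
layer type, and each inner dict is then normalized on its own. The layer-type
names and values below are illustrative, not taken from any real config.

    # Hypothetical per-layer-type rope_parameters (illustrative values).
    from transformers import PretrainedConfig

    config = PretrainedConfig()
    config.rope_parameters = {
        "full_attention": {"rope_type": "default", "rope_theta": 1000000.0},
        "sliding_attention": {"rope_type": "default", "rope_theta": 10000.0},
    }
    # Because these top-level keys form a subset of ALLOWED_LAYER_TYPES, the
    # nested branch calls patch_rope_parameters_dict() on each inner dict
    # rather than treating the outer mapping as a single RoPE spec.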