Update rope_scaling to rope_parameters in preparation for Transformers v5 (#28542)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-11-19 18:06:36 +01:00
Committed by: GitHub
Parent: d44e9df7d4
Commit: a8b70304d6
104 changed files with 542 additions and 910 deletions
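For context, a minimal sketch of the override migration this commit applies in the tests, assuming the Transformers v5 convention of folding rope_theta and rope_scaling into a single rope_parameters dict (the ModelConfig API and key names are taken from the updated test below):

from vllm.config import ModelConfig

# Pre-v5 style (removed in this commit): rope_theta and rope_scaling were
# passed as separate overrides.
old_style = ModelConfig(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    hf_overrides={
        "rope_scaling": {"rope_type": "dynamic", "factor": 2.0},
        "rope_theta": 16_000_000.0,
    },
)

# v5 style (added in this commit): a single rope_parameters dict carries both.
new_style = ModelConfig(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    hf_overrides={
        "rope_parameters": {
            "rope_theta": 16_000_000.0,
            "rope_type": "dynamic",
            "factor": 2.0,
        }
    },
)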


@@ -249,45 +249,48 @@ def test_get_bert_tokenization_sentence_transformer_config():
 def test_rope_customization():
-    TEST_ROPE_SCALING = {"rope_type": "dynamic", "factor": 2.0}
-    TEST_ROPE_THETA = 16_000_000.0
-    LONGCHAT_ROPE_SCALING = {"rope_type": "linear", "factor": 8.0}
+    TEST_ROPE_PARAMETERS = {
+        "rope_theta": 16_000_000.0,
+        "rope_type": "dynamic",
+        "factor": 2.0,
+    }
+    LLAMA_ROPE_PARAMETERS = {"rope_theta": 500000.0, "rope_type": "default"}
+    LONGCHAT_ROPE_PARAMETERS = {"rope_type": "linear", "factor": 8.0}
 
     llama_model_config = ModelConfig("meta-llama/Meta-Llama-3-8B-Instruct")
-    assert getattr(llama_model_config.hf_config, "rope_scaling", None) is None
-    assert getattr(llama_model_config.hf_config, "rope_theta", None) == 500_000
+    assert (
+        getattr(llama_model_config.hf_config, "rope_parameters", None)
+        == LLAMA_ROPE_PARAMETERS
+    )
     assert llama_model_config.max_model_len == 8192
 
     llama_model_config = ModelConfig(
         "meta-llama/Meta-Llama-3-8B-Instruct",
-        hf_overrides={
-            "rope_scaling": TEST_ROPE_SCALING,
-            "rope_theta": TEST_ROPE_THETA,
-        },
+        hf_overrides={"rope_parameters": TEST_ROPE_PARAMETERS},
     )
     assert (
-        getattr(llama_model_config.hf_config, "rope_scaling", None) == TEST_ROPE_SCALING
+        getattr(llama_model_config.hf_config, "rope_parameters", None)
+        == TEST_ROPE_PARAMETERS
     )
-    assert getattr(llama_model_config.hf_config, "rope_theta", None) == TEST_ROPE_THETA
     assert llama_model_config.max_model_len == 16384
 
     longchat_model_config = ModelConfig("lmsys/longchat-13b-16k")
-    # Check if LONGCHAT_ROPE_SCALING entries are in longchat_model_config
+    # Check if LONGCHAT_ROPE_PARAMETERS entries are in longchat_model_config
     assert all(
-        longchat_model_config.hf_config.rope_scaling.get(key) == value
-        for key, value in LONGCHAT_ROPE_SCALING.items()
+        longchat_model_config.hf_config.rope_parameters.get(key) == value
+        for key, value in LONGCHAT_ROPE_PARAMETERS.items()
     )
     assert longchat_model_config.max_model_len == 16384
 
     longchat_model_config = ModelConfig(
         "lmsys/longchat-13b-16k",
         hf_overrides={
-            "rope_scaling": TEST_ROPE_SCALING,
+            "rope_parameters": TEST_ROPE_PARAMETERS,
         },
     )
     assert (
-        getattr(longchat_model_config.hf_config, "rope_scaling", None)
-        == TEST_ROPE_SCALING
+        getattr(longchat_model_config.hf_config, "rope_parameters", None)
+        == TEST_ROPE_PARAMETERS
     )
     assert longchat_model_config.max_model_len == 4096
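A quick way to sanity-check the new shape outside the test, assuming the same ModelConfig API exercised above:

from vllm.config import ModelConfig

cfg = ModelConfig(
    "lmsys/longchat-13b-16k",
    hf_overrides={"rope_parameters": {"rope_type": "linear", "factor": 8.0}},
)
# The merged dict is exposed on the loaded HF config; the test above only
# asserts the overridden keys, so other keys (e.g. rope_theta) may also appear.
print(cfg.hf_config.rope_parameters)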