Update rope_scaling to rope_parameters in preparation for Transformers v5 (#28542)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -249,45 +249,48 @@ def test_get_bert_tokenization_sentence_transformer_config():
|
||||
|
||||
|
||||
def test_rope_customization():
    """Verify RoPE settings surface on ``hf_config.rope_parameters``.

    Covers the Transformers-v5 rename of ``rope_scaling``/``rope_theta`` to a
    single ``rope_parameters`` dict. Checks three things per model:
      * defaults loaded from the checkpoint config,
      * overrides supplied via ``hf_overrides={"rope_parameters": ...}``,
      * the resulting ``max_model_len`` (dynamic scaling doubles Llama-3's
        8192 context to 16384; overriding longchat drops it to 4096).
    """
    # Override used for both models: dynamic NTK scaling, factor 2, huge theta.
    TEST_ROPE_PARAMETERS = {
        "rope_theta": 16_000_000.0,
        "rope_type": "dynamic",
        "factor": 2.0,
    }
    # Expected checkpoint defaults.
    LLAMA_ROPE_PARAMETERS = {"rope_theta": 500000.0, "rope_type": "default"}
    LONGCHAT_ROPE_PARAMETERS = {"rope_type": "linear", "factor": 8.0}

    # Llama-3 ships no scaling: defaults come straight from the checkpoint.
    llama_model_config = ModelConfig("meta-llama/Meta-Llama-3-8B-Instruct")
    assert (
        getattr(llama_model_config.hf_config, "rope_parameters", None)
        == LLAMA_ROPE_PARAMETERS
    )
    assert llama_model_config.max_model_len == 8192

    # Overriding rope_parameters must be reflected verbatim on hf_config,
    # and dynamic factor=2.0 doubles the usable context length.
    llama_model_config = ModelConfig(
        "meta-llama/Meta-Llama-3-8B-Instruct",
        hf_overrides={"rope_parameters": TEST_ROPE_PARAMETERS},
    )
    assert (
        getattr(llama_model_config.hf_config, "rope_parameters", None)
        == TEST_ROPE_PARAMETERS
    )
    assert llama_model_config.max_model_len == 16384

    longchat_model_config = ModelConfig("lmsys/longchat-13b-16k")
    # Check if LONGCHAT_ROPE_PARAMETERS entries are in longchat_model_config.
    # Subset check (not equality): the config may carry extra keys such as
    # rope_theta alongside the linear-scaling entries.
    assert all(
        longchat_model_config.hf_config.rope_parameters.get(key) == value
        for key, value in LONGCHAT_ROPE_PARAMETERS.items()
    )
    assert longchat_model_config.max_model_len == 16384

    # Overriding longchat's built-in linear scaling with the test parameters
    # removes the 4x extension, shrinking max_model_len back to 4096.
    longchat_model_config = ModelConfig(
        "lmsys/longchat-13b-16k",
        hf_overrides={
            "rope_parameters": TEST_ROPE_PARAMETERS,
        },
    )
    assert (
        getattr(longchat_model_config.hf_config, "rope_parameters", None)
        == TEST_ROPE_PARAMETERS
    )
    assert longchat_model_config.max_model_len == 4096
|
||||
|
||||
|
||||
Reference in New Issue
Block a user