[Frontend] Dynamic RoPE scaling (#4638)
This commit is contained in:
@@ -36,4 +36,58 @@ def test_get_sliding_window():
|
||||
assert mistral_model_config.get_sliding_window() is None
|
||||
|
||||
mistral_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW
|
||||
assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW
|
||||
assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW
|
||||
|
||||
|
||||
def test_rope_scaling():
|
||||
TEST_ROPE_SCALING = {"type": "dynamic", "factor": 2.0}
|
||||
LONGCHAT_ROPE_SCALING = {"type": "linear", "factor": 8.0}
|
||||
|
||||
llama_model_config = ModelConfig(
|
||||
"meta-llama/Meta-Llama-3-8B-Instruct",
|
||||
"meta-llama/Meta-Llama-3-8B-Instruct",
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
dtype="float16",
|
||||
seed=0,
|
||||
)
|
||||
assert getattr(llama_model_config.hf_config, "rope_scaling", None) is None
|
||||
assert llama_model_config.max_model_len == 8192
|
||||
|
||||
llama_model_config = ModelConfig(
|
||||
"meta-llama/Meta-Llama-3-8B-Instruct",
|
||||
"meta-llama/Meta-Llama-3-8B-Instruct",
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
dtype="float16",
|
||||
seed=0,
|
||||
rope_scaling=TEST_ROPE_SCALING,
|
||||
)
|
||||
assert getattr(llama_model_config.hf_config, "rope_scaling",
|
||||
None) == TEST_ROPE_SCALING
|
||||
assert llama_model_config.max_model_len == 16384
|
||||
|
||||
longchat_model_config = ModelConfig(
|
||||
"lmsys/longchat-13b-16k",
|
||||
"lmsys/longchat-13b-16k",
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
dtype="float16",
|
||||
seed=0,
|
||||
)
|
||||
assert getattr(longchat_model_config.hf_config, "rope_scaling",
|
||||
None) == LONGCHAT_ROPE_SCALING
|
||||
assert longchat_model_config.max_model_len == 16384
|
||||
|
||||
longchat_model_config = ModelConfig(
|
||||
"lmsys/longchat-13b-16k",
|
||||
"lmsys/longchat-13b-16k",
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
dtype="float16",
|
||||
seed=0,
|
||||
rope_scaling=TEST_ROPE_SCALING,
|
||||
)
|
||||
assert getattr(longchat_model_config.hf_config, "rope_scaling",
|
||||
None) == TEST_ROPE_SCALING
|
||||
assert longchat_model_config.max_model_len == 4096
|
||||
|
||||
Reference in New Issue
Block a user