[Misc] Consolidate ModelConfig code related to HF config (#10104)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2024-11-07 14:00:21 +08:00
committed by GitHub
parent 1fa020c539
commit db7db4aab9
10 changed files with 68 additions and 43 deletions

View File

@@ -165,3 +165,41 @@ def test_rope_customization():
assert getattr(longchat_model_config.hf_config, "rope_scaling",
None) == TEST_ROPE_SCALING
assert longchat_model_config.max_model_len == 4096
@pytest.mark.parametrize(("model_id", "is_encoder_decoder"), [
    ("facebook/opt-125m", False),
    ("facebook/bart-base", True),
    ("meta-llama/Llama-3.2-1B", False),
    ("meta-llama/Llama-3.2-11B-Vision", True),
])
def test_is_encoder_decoder(model_id, is_encoder_decoder):
    """Check that ``ModelConfig.is_encoder_decoder`` matches the expected
    value for each model architecture."""
    # Everything except the positional model name is held constant across cases.
    config_kwargs = dict(
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        dtype="float16",
        seed=0,
    )
    config = ModelConfig(model_id, **config_kwargs)

    assert config.is_encoder_decoder == is_encoder_decoder
@pytest.mark.parametrize(("model_id", "uses_mrope"), [
    ("facebook/opt-125m", False),
    ("Qwen/Qwen2-VL-2B-Instruct", True),
])
def test_uses_mrope(model_id, uses_mrope):
    """Check that ``ModelConfig.uses_mrope`` is True only for models that use
    multimodal rotary position embeddings (e.g. Qwen2-VL)."""
    # Everything except the positional model name is held constant across cases.
    config_kwargs = dict(
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        dtype="float16",
        seed=0,
    )
    config = ModelConfig(model_id, **config_kwargs)

    assert config.uses_mrope == uses_mrope