[Misc] Consolidate ModelConfig code related to HF config (#10104)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -165,3 +165,41 @@ def test_rope_customization():
|
||||
assert getattr(longchat_model_config.hf_config, "rope_scaling",
|
||||
None) == TEST_ROPE_SCALING
|
||||
assert longchat_model_config.max_model_len == 4096
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model_id", "is_encoder_decoder"), [
    ("facebook/opt-125m", False),
    ("facebook/bart-base", True),
    ("meta-llama/Llama-3.2-1B", False),
    ("meta-llama/Llama-3.2-11B-Vision", True),
])
def test_is_encoder_decoder(model_id, is_encoder_decoder):
    """Check that ModelConfig reports encoder-decoder status per model."""
    # Everything except the model id is held at default/auto settings so the
    # detection depends only on the HF config of the model under test.
    fixed_kwargs = dict(
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        dtype="float16",
        seed=0,
    )
    config = ModelConfig(model_id, **fixed_kwargs)

    assert config.is_encoder_decoder == is_encoder_decoder
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model_id", "uses_mrope"), [
    ("facebook/opt-125m", False),
    ("Qwen/Qwen2-VL-2B-Instruct", True),
])
def test_uses_mrope(model_id, uses_mrope):
    """Check that ModelConfig detects M-RoPE usage from the model's config."""
    # Only the model id varies between cases; all other settings stay at
    # their default/auto values so the result reflects the HF config alone.
    fixed_kwargs = dict(
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        dtype="float16",
        seed=0,
    )
    config = ModelConfig(model_id, **fixed_kwargs)

    assert config.uses_mrope == uses_mrope
|
||||
|
||||
Reference in New Issue
Block a user