Improve configs - ModelConfig (#17130)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-04-30 03:38:22 +01:00
committed by GitHub
parent 2c4f59afc3
commit 13698db634
36 changed files with 490 additions and 648 deletions

View File

@@ -185,7 +185,7 @@ def test_get_pooling_config():
revision=None,
)
pooling_config = model_config._init_pooler_config(None)
pooling_config = model_config._init_pooler_config()
assert pooling_config is not None
assert pooling_config.normalize
@@ -205,11 +205,12 @@ def test_get_pooling_config_from_args():
dtype="float16",
revision=None)
override_config = PoolerConfig(pooling_type='CLS', normalize=True)
override_pooler_config = PoolerConfig(pooling_type='CLS', normalize=True)
model_config.override_pooler_config = override_pooler_config
pooling_config = model_config._init_pooler_config(override_config)
pooling_config = model_config._init_pooler_config()
assert pooling_config is not None
assert asdict(pooling_config) == asdict(override_config)
assert asdict(pooling_config) == asdict(override_pooler_config)
@pytest.mark.skipif(current_platform.is_rocm(),