Improve configs - ModelConfig (#17130)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -185,7 +185,7 @@ def test_get_pooling_config():
|
||||
revision=None,
|
||||
)
|
||||
|
||||
pooling_config = model_config._init_pooler_config(None)
|
||||
pooling_config = model_config._init_pooler_config()
|
||||
assert pooling_config is not None
|
||||
|
||||
assert pooling_config.normalize
|
||||
@@ -205,11 +205,12 @@ def test_get_pooling_config_from_args():
|
||||
dtype="float16",
|
||||
revision=None)
|
||||
|
||||
override_config = PoolerConfig(pooling_type='CLS', normalize=True)
|
||||
override_pooler_config = PoolerConfig(pooling_type='CLS', normalize=True)
|
||||
model_config.override_pooler_config = override_pooler_config
|
||||
|
||||
pooling_config = model_config._init_pooler_config(override_config)
|
||||
pooling_config = model_config._init_pooler_config()
|
||||
assert pooling_config is not None
|
||||
assert asdict(pooling_config) == asdict(override_config)
|
||||
assert asdict(pooling_config) == asdict(override_pooler_config)
|
||||
|
||||
|
||||
@pytest.mark.skipif(current_platform.is_rocm(),
|
||||
|
||||
Reference in New Issue
Block a user