Move PoolerConfig from config/__init__.py to config/pooler.py (#25181)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -207,25 +207,19 @@ def test_get_pooling_config():
|
||||
model_id = "sentence-transformers/all-MiniLM-L12-v2"
|
||||
model_config = ModelConfig(model_id)
|
||||
|
||||
pooling_config = model_config._init_pooler_config()
|
||||
assert pooling_config is not None
|
||||
|
||||
assert pooling_config.normalize
|
||||
assert pooling_config.pooling_type == PoolingType.MEAN.name
|
||||
assert model_config.pooler_config is not None
|
||||
assert model_config.pooler_config.normalize
|
||||
assert model_config.pooler_config.pooling_type == PoolingType.MEAN.name
|
||||
|
||||
|
||||
@pytest.mark.skipif(current_platform.is_rocm(),
|
||||
reason="Xformers backend is not supported on ROCm.")
|
||||
def test_get_pooling_config_from_args():
|
||||
model_id = "sentence-transformers/all-MiniLM-L12-v2"
|
||||
model_config = ModelConfig(model_id)
|
||||
pooler_config = PoolerConfig(pooling_type="CLS", normalize=True)
|
||||
model_config = ModelConfig(model_id, pooler_config=pooler_config)
|
||||
|
||||
override_pooler_config = PoolerConfig(pooling_type='CLS', normalize=True)
|
||||
model_config.override_pooler_config = override_pooler_config
|
||||
|
||||
pooling_config = model_config._init_pooler_config()
|
||||
assert pooling_config is not None
|
||||
assert asdict(pooling_config) == asdict(override_pooler_config)
|
||||
assert asdict(model_config.pooler_config) == asdict(pooler_config)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
||||
Reference in New Issue
Block a user