Move PoolerConfig from config/__init__.py to config/pooler.py (#25181)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-09-19 12:02:55 +01:00
committed by GitHub
parent 1dfea5f4a9
commit 058525b997
14 changed files with 193 additions and 179 deletions

View File

@@ -207,25 +207,19 @@ def test_get_pooling_config():
model_id = "sentence-transformers/all-MiniLM-L12-v2"
model_config = ModelConfig(model_id)
pooling_config = model_config._init_pooler_config()
assert pooling_config is not None
assert pooling_config.normalize
assert pooling_config.pooling_type == PoolingType.MEAN.name
assert model_config.pooler_config is not None
assert model_config.pooler_config.normalize
assert model_config.pooler_config.pooling_type == PoolingType.MEAN.name
@pytest.mark.skipif(current_platform.is_rocm(),
reason="Xformers backend is not supported on ROCm.")
def test_get_pooling_config_from_args():
model_id = "sentence-transformers/all-MiniLM-L12-v2"
model_config = ModelConfig(model_id)
pooler_config = PoolerConfig(pooling_type="CLS", normalize=True)
model_config = ModelConfig(model_id, pooler_config=pooler_config)
override_pooler_config = PoolerConfig(pooling_type='CLS', normalize=True)
model_config.override_pooler_config = override_pooler_config
pooling_config = model_config._init_pooler_config()
assert pooling_config is not None
assert asdict(pooling_config) == asdict(override_pooler_config)
assert asdict(model_config.pooler_config) == asdict(pooler_config)
@pytest.mark.parametrize(