[Bugfix] Replace PoolingParams.normalize with use_activation (#32243)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -18,7 +18,7 @@ EMBEDDING_MODELS = [
|
||||
]
|
||||
|
||||
classify_parameters = ["use_activation"]
|
||||
embed_parameters = ["dimensions", "normalize"]
|
||||
embed_parameters = ["dimensions", "use_activation"]
|
||||
step_pooling_parameters = ["step_tag_id", "returned_token_ids"]
|
||||
|
||||
|
||||
@@ -42,17 +42,17 @@ def test_embed():
|
||||
task = "embed"
|
||||
model_config = MockModelConfig(pooler_config=PoolerConfig(seq_pooling_type="CLS"))
|
||||
|
||||
pooling_params = PoolingParams(normalize=None)
|
||||
pooling_params = PoolingParams(use_activation=None)
|
||||
pooling_params.verify(task=task, model_config=model_config)
|
||||
|
||||
pooling_params = PoolingParams(normalize=True)
|
||||
pooling_params = PoolingParams(use_activation=True)
|
||||
pooling_params.verify(task=task, model_config=model_config)
|
||||
|
||||
pooling_params = PoolingParams(normalize=False)
|
||||
pooling_params = PoolingParams(use_activation=False)
|
||||
pooling_params.verify(task=task, model_config=model_config)
|
||||
|
||||
invalid_parameters = classify_parameters + step_pooling_parameters
|
||||
for p in invalid_parameters:
|
||||
for p in set(invalid_parameters) - set(embed_parameters):
|
||||
with pytest.raises(ValueError):
|
||||
pooling_params = PoolingParams(**{p: True})
|
||||
pooling_params.verify(task=task, model_config=model_config)
|
||||
@@ -98,7 +98,7 @@ def test_classify(task):
|
||||
pooling_params.verify(task=task, model_config=model_config)
|
||||
|
||||
invalid_parameters = embed_parameters + step_pooling_parameters
|
||||
for p in invalid_parameters:
|
||||
for p in set(invalid_parameters) - set(classify_parameters):
|
||||
with pytest.raises(ValueError):
|
||||
pooling_params = PoolingParams(**{p: True})
|
||||
pooling_params.verify(task=task, model_config=model_config)
|
||||
@@ -111,20 +111,20 @@ def test_token_embed(pooling_type: str):
|
||||
pooler_config=PoolerConfig(tok_pooling_type=pooling_type)
|
||||
)
|
||||
|
||||
pooling_params = PoolingParams(normalize=None)
|
||||
pooling_params = PoolingParams(use_activation=None)
|
||||
pooling_params.verify(task=task, model_config=model_config)
|
||||
|
||||
pooling_params = PoolingParams(normalize=True)
|
||||
pooling_params = PoolingParams(use_activation=True)
|
||||
pooling_params.verify(task=task, model_config=model_config)
|
||||
|
||||
pooling_params = PoolingParams(normalize=False)
|
||||
pooling_params = PoolingParams(use_activation=False)
|
||||
pooling_params.verify(task=task, model_config=model_config)
|
||||
|
||||
invalid_parameters = classify_parameters
|
||||
if pooling_type != "STEP":
|
||||
invalid_parameters = classify_parameters + step_pooling_parameters
|
||||
|
||||
for p in invalid_parameters:
|
||||
for p in set(invalid_parameters) - set(embed_parameters):
|
||||
with pytest.raises(ValueError):
|
||||
pooling_params = PoolingParams(**{p: True})
|
||||
pooling_params.verify(task=task, model_config=model_config)
|
||||
@@ -150,7 +150,7 @@ def test_token_classify(pooling_type: str):
|
||||
if pooling_type != "STEP":
|
||||
invalid_parameters = embed_parameters + step_pooling_parameters
|
||||
|
||||
for p in invalid_parameters:
|
||||
for p in set(invalid_parameters) - set(classify_parameters):
|
||||
with pytest.raises(ValueError):
|
||||
pooling_params = PoolingParams(**{p: True})
|
||||
pooling_params.verify(task=task, model_config=model_config)
|
||||
|
||||
Reference in New Issue
Block a user