[Refactor] Separate sequence and token pooling types (#32026)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -40,7 +40,7 @@ def test_task():
|
||||
|
||||
def test_embed():
|
||||
task = "embed"
|
||||
model_config = MockModelConfig(pooler_config=PoolerConfig(pooling_type="CLS"))
|
||||
model_config = MockModelConfig(pooler_config=PoolerConfig(seq_pooling_type="CLS"))
|
||||
|
||||
pooling_params = PoolingParams(normalize=None)
|
||||
pooling_params.verify(task=task, model_config=model_config)
|
||||
@@ -86,7 +86,7 @@ def test_embed_dimensions(model_info: EmbedModelInfo):
|
||||
|
||||
@pytest.mark.parametrize("task", ["score", "classify"])
|
||||
def test_classify(task):
|
||||
model_config = MockModelConfig(pooler_config=PoolerConfig(pooling_type="CLS"))
|
||||
model_config = MockModelConfig(pooler_config=PoolerConfig(seq_pooling_type="CLS"))
|
||||
|
||||
pooling_params = PoolingParams(use_activation=None)
|
||||
pooling_params.verify(task=task, model_config=model_config)
|
||||
@@ -108,7 +108,7 @@ def test_classify(task):
|
||||
def test_token_embed(pooling_type: str):
|
||||
task = "token_embed"
|
||||
model_config = MockModelConfig(
|
||||
pooler_config=PoolerConfig(pooling_type=pooling_type)
|
||||
pooler_config=PoolerConfig(tok_pooling_type=pooling_type)
|
||||
)
|
||||
|
||||
pooling_params = PoolingParams(normalize=None)
|
||||
@@ -134,7 +134,7 @@ def test_token_embed(pooling_type: str):
|
||||
def test_token_classify(pooling_type: str):
|
||||
task = "token_classify"
|
||||
model_config = MockModelConfig(
|
||||
pooler_config=PoolerConfig(pooling_type=pooling_type)
|
||||
pooler_config=PoolerConfig(tok_pooling_type=pooling_type)
|
||||
)
|
||||
|
||||
pooling_params = PoolingParams(use_activation=None)
|
||||
|
||||
Reference in New Issue
Block a user