[Frontend][Doc][5/N] Improve all pooling task | Polish encode (pooling) api & Document. (#25524)
Signed-off-by: wang.yuqi <noooop@126.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
@@ -17,7 +17,7 @@ EMBEDDING_MODELS = [
|
||||
),
|
||||
]
|
||||
|
||||
classify_parameters = ["activation"]
|
||||
classify_parameters = ["use_activation"]
|
||||
embed_parameters = ["dimensions", "normalize"]
|
||||
step_pooling_parameters = ["step_tag_id", "returned_token_ids"]
|
||||
|
||||
@@ -88,13 +88,13 @@ def test_embed_dimensions(model_info: EmbedModelInfo):
|
||||
def test_classify(task):
|
||||
model_config = MockModelConfig(pooler_config=PoolerConfig(pooling_type="CLS"))
|
||||
|
||||
pooling_params = PoolingParams(activation=None)
|
||||
pooling_params = PoolingParams(use_activation=None)
|
||||
pooling_params.verify(task=task, model_config=model_config)
|
||||
|
||||
pooling_params = PoolingParams(activation=True)
|
||||
pooling_params = PoolingParams(use_activation=True)
|
||||
pooling_params.verify(task=task, model_config=model_config)
|
||||
|
||||
pooling_params = PoolingParams(activation=False)
|
||||
pooling_params = PoolingParams(use_activation=False)
|
||||
pooling_params.verify(task=task, model_config=model_config)
|
||||
|
||||
invalid_parameters = embed_parameters + step_pooling_parameters
|
||||
@@ -137,13 +137,13 @@ def test_token_classify(pooling_type: str):
|
||||
pooler_config=PoolerConfig(pooling_type=pooling_type)
|
||||
)
|
||||
|
||||
pooling_params = PoolingParams(activation=None)
|
||||
pooling_params = PoolingParams(use_activation=None)
|
||||
pooling_params.verify(task=task, model_config=model_config)
|
||||
|
||||
pooling_params = PoolingParams(activation=True)
|
||||
pooling_params = PoolingParams(use_activation=True)
|
||||
pooling_params.verify(task=task, model_config=model_config)
|
||||
|
||||
pooling_params = PoolingParams(activation=False)
|
||||
pooling_params = PoolingParams(use_activation=False)
|
||||
pooling_params.verify(task=task, model_config=model_config)
|
||||
|
||||
invalid_parameters = embed_parameters
|
||||
|
||||
Reference in New Issue
Block a user