[Chore] Further cleanup pooler (#31951)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -5,12 +5,7 @@ import os
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.model_executor.layers.pooler import (
|
||||
CLSPool,
|
||||
DispatchPooler,
|
||||
MeanPool,
|
||||
PoolingType,
|
||||
)
|
||||
from vllm.model_executor.layers.pooler import CLSPool, DispatchPooler, MeanPool
|
||||
from vllm.model_executor.models.bert import BertEmbeddingModel
|
||||
from vllm.model_executor.models.roberta import RobertaEmbeddingModel
|
||||
from vllm.platforms import current_platform
|
||||
@@ -50,7 +45,7 @@ def test_model_loading_with_params(vllm_runner, monkeypatch):
|
||||
assert model_config.encoder_config["do_lower_case"]
|
||||
|
||||
# asserts on the pooling config files
|
||||
assert model_config.pooler_config.pooling_type == PoolingType.CLS.name
|
||||
assert model_config.pooler_config.pooling_type == "CLS"
|
||||
assert model_config.pooler_config.normalize
|
||||
|
||||
# asserts on the tokenizer loaded
|
||||
@@ -94,7 +89,7 @@ def test_roberta_model_loading_with_params(vllm_runner, monkeypatch):
|
||||
assert not model_config.encoder_config["do_lower_case"]
|
||||
|
||||
# asserts on the pooling config files
|
||||
assert model_config.pooler_config.pooling_type == PoolingType.MEAN.name
|
||||
assert model_config.pooler_config.pooling_type == "MEAN"
|
||||
assert model_config.pooler_config.normalize
|
||||
|
||||
# asserts on the tokenizer loaded
|
||||
|
||||
Reference in New Issue
Block a user