[Bugfix][FailingTest]Fix test_model_load_with_params.py (#18758)
Signed-off-by: rabi <ramishra@redhat.com>
This commit is contained in:
@@ -4,7 +4,7 @@ import os
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.model_executor.layers.pooler import CLSPool, PoolingType
|
||||
from vllm.model_executor.layers.pooler import CLSPool, MeanPool, PoolingType
|
||||
from vllm.model_executor.models.bert import BertEmbeddingModel
|
||||
from vllm.model_executor.models.roberta import RobertaEmbeddingModel
|
||||
from vllm.platforms import current_platform
|
||||
@@ -14,7 +14,7 @@ MODEL_NAME = os.environ.get("MODEL_NAME", "BAAI/bge-base-en-v1.5")
|
||||
REVISION = os.environ.get("REVISION", "main")
|
||||
|
||||
MODEL_NAME_ROBERTA = os.environ.get("MODEL_NAME",
|
||||
"intfloat/multilingual-e5-small")
|
||||
"intfloat/multilingual-e5-base")
|
||||
REVISION_ROBERTA = os.environ.get("REVISION", "main")
|
||||
|
||||
|
||||
@@ -40,17 +40,15 @@ def test_model_loading_with_params(vllm_runner):
|
||||
|
||||
# asserts on the pooling config files
|
||||
assert model_config.pooler_config.pooling_type == PoolingType.CLS.name
|
||||
assert model_config.pooler_config.pooling_norm
|
||||
assert model_config.pooler_config.normalize
|
||||
|
||||
# asserts on the tokenizer loaded
|
||||
assert model_tokenizer.tokenizer_id == "BAAI/bge-base-en-v1.5"
|
||||
assert model_tokenizer.tokenizer_config["do_lower_case"]
|
||||
assert model_tokenizer.tokenizer.model_max_length == 512
|
||||
|
||||
def check_model(model):
|
||||
assert isinstance(model, BertEmbeddingModel)
|
||||
assert model._pooler.pooling_type == PoolingType.CLS
|
||||
assert model._pooler.normalize
|
||||
assert isinstance(model._pooler, CLSPool)
|
||||
|
||||
vllm_model.apply_model(check_model)
|
||||
|
||||
@@ -80,16 +78,15 @@ def test_roberta_model_loading_with_params(vllm_runner):
|
||||
|
||||
# asserts on the pooling config files
|
||||
assert model_config.pooler_config.pooling_type == PoolingType.MEAN.name
|
||||
assert model_config.pooler_config.pooling_norm
|
||||
assert model_config.pooler_config.normalize
|
||||
|
||||
# asserts on the tokenizer loaded
|
||||
assert model_tokenizer.tokenizer_id == "intfloat/multilingual-e5-small"
|
||||
assert not model_tokenizer.tokenizer_config["do_lower_case"]
|
||||
assert model_tokenizer.tokenizer_id == "intfloat/multilingual-e5-base"
|
||||
assert model_tokenizer.tokenizer.model_max_length == 512
|
||||
|
||||
def check_model(model):
|
||||
assert isinstance(model, RobertaEmbeddingModel)
|
||||
assert model._pooler.pooling_type == PoolingType.MEAN
|
||||
assert model._pooler.normalize
|
||||
assert isinstance(model._pooler, MeanPool)
|
||||
|
||||
vllm_model.apply_model(check_model)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user