Adds method to read the pooling types from model's files (#9506)
Signed-off-by: Flavia Beo <flavia.beo@ibm.com> Signed-off-by: Max de Bayser <mbayser@br.ibm.com> Co-authored-by: Max de Bayser <mbayser@br.ibm.com>
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
import pytest
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.model_executor.layers.pooler import PoolingType
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model_id", "expected_task"), [
|
||||
@@ -102,6 +104,76 @@ def test_get_sliding_window():
|
||||
assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW
|
||||
|
||||
|
||||
@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="Xformers backend is not supported on ROCm.")
def test_get_pooling_config():
    """Check the pooler config produced when no pooling override is passed.

    Every pooling argument is ``None`` here, so the asserted values must
    come from somewhere other than the call site (presumably the model's
    own configuration files -- confirm against ``_init_pooler_config``).
    """
    model_name = "sentence-transformers/all-MiniLM-L12-v2"
    model_config = ModelConfig(
        model_name,
        task="auto",
        tokenizer=model_name,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        revision=None,
    )

    # No explicit overrides: every pooling option is left unset.
    unset_overrides = {
        "pooling_type": None,
        "pooling_norm": None,
        "pooling_returned_token_ids": None,
        "pooling_softmax": None,
        "pooling_step_tag_id": None,
    }
    pooler_config = model_config._init_pooler_config(**unset_overrides)

    assert pooler_config.pooling_norm
    assert pooler_config.pooling_type == PoolingType.MEAN.name
|
||||
|
||||
|
||||
@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="Xformers backend is not supported on ROCm.")
def test_get_pooling_config_from_args():
    """Check that explicitly passed pooling arguments end up in the config.

    ``pooling_type='CLS'`` and ``pooling_norm=True`` are supplied directly
    and the resulting pooler config is expected to report exactly those.
    """
    model_name = "sentence-transformers/all-MiniLM-L12-v2"
    model_config = ModelConfig(
        model_name,
        task="auto",
        tokenizer=model_name,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        revision=None,
    )

    # Only the pooling type and normalisation flag are overridden; the
    # remaining options stay unset.
    explicit_overrides = {
        "pooling_type": 'CLS',
        "pooling_norm": True,
        "pooling_returned_token_ids": None,
        "pooling_softmax": None,
        "pooling_step_tag_id": None,
    }
    pooler_config = model_config._init_pooler_config(**explicit_overrides)

    assert pooler_config.pooling_norm
    assert pooler_config.pooling_type == PoolingType.CLS.name
|
||||
|
||||
|
||||
@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="Xformers backend is not supported on ROCm.")
def test_get_bert_tokenization_sentence_transformer_config():
    """Check the encoder config values reported for BAAI/bge-base-en-v1.5.

    The mapping returned by ``_get_encoder_config()`` is expected to carry
    a ``max_seq_length`` of 512 and a truthy ``do_lower_case`` entry.
    """
    model_name = "BAAI/bge-base-en-v1.5"
    model_config = ModelConfig(
        model=model_name,
        task="auto",
        tokenizer=model_name,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        revision=None,
    )

    encoder_config = model_config._get_encoder_config()

    assert encoder_config["max_seq_length"] == 512
    assert encoder_config["do_lower_case"]
|
||||
|
||||
|
||||
def test_rope_customization():
|
||||
TEST_ROPE_SCALING = {"rope_type": "dynamic", "factor": 2.0}
|
||||
TEST_ROPE_THETA = 16_000_000.0
|
||||
|
||||
Reference in New Issue
Block a user