Adds method to read the pooling types from model's files (#9506)

Signed-off-by: Flavia Beo <flavia.beo@ibm.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Co-authored-by: Max de Bayser <mbayser@br.ibm.com>
This commit is contained in:
Flávia Béo
2024-11-07 05:42:40 -03:00
committed by GitHub
parent e036e527a0
commit aa9078fa03
10 changed files with 342 additions and 25 deletions

View File

@@ -1,6 +1,8 @@
import pytest
from vllm.config import ModelConfig
from vllm.model_executor.layers.pooler import PoolingType
from vllm.platforms import current_platform
@pytest.mark.parametrize(("model_id", "expected_task"), [
@@ -102,6 +104,76 @@ def test_get_sliding_window():
assert mistral_model_config.get_sliding_window() == TEST_SLIDING_WINDOW
@pytest.mark.skipif(current_platform.is_rocm(),
reason="Xformers backend is not supported on ROCm.")
def test_get_pooling_config():
model_id = "sentence-transformers/all-MiniLM-L12-v2"
minilm_model_config = ModelConfig(
model_id,
task="auto",
tokenizer=model_id,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="float16",
revision=None,
)
minilm_pooling_config = minilm_model_config._init_pooler_config(
pooling_type=None,
pooling_norm=None,
pooling_returned_token_ids=None,
pooling_softmax=None,
pooling_step_tag_id=None)
assert minilm_pooling_config.pooling_norm
assert minilm_pooling_config.pooling_type == PoolingType.MEAN.name
@pytest.mark.skipif(current_platform.is_rocm(),
reason="Xformers backend is not supported on ROCm.")
def test_get_pooling_config_from_args():
model_id = "sentence-transformers/all-MiniLM-L12-v2"
minilm_model_config = ModelConfig(model_id,
task="auto",
tokenizer=model_id,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="float16",
revision=None)
minilm_pooling_config = minilm_model_config._init_pooler_config(
pooling_type='CLS',
pooling_norm=True,
pooling_returned_token_ids=None,
pooling_softmax=None,
pooling_step_tag_id=None)
assert minilm_pooling_config.pooling_norm
assert minilm_pooling_config.pooling_type == PoolingType.CLS.name
@pytest.mark.skipif(current_platform.is_rocm(),
reason="Xformers backend is not supported on ROCm.")
def test_get_bert_tokenization_sentence_transformer_config():
bge_model_config = ModelConfig(
model="BAAI/bge-base-en-v1.5",
task="auto",
tokenizer="BAAI/bge-base-en-v1.5",
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="float16",
revision=None,
)
bert_bge_model_config = bge_model_config._get_encoder_config()
assert bert_bge_model_config["max_seq_length"] == 512
assert bert_bge_model_config["do_lower_case"]
def test_rope_customization():
TEST_ROPE_SCALING = {"rope_type": "dynamic", "factor": 2.0}
TEST_ROPE_THETA = 16_000_000.0