[Model] Add user-configurable task for models that support both generation and embedding (#9424)

Cyrus Leung
2024-10-19 02:31:58 +08:00
committed by GitHub
parent 7dbe738d65
commit 051eaf6db3
33 changed files with 451 additions and 201 deletions
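The new `task` argument on ModelConfig lets users pick between the generation and embedding runners for checkpoints whose architecture supports both, while task="auto" keeps the previous behavior of inferring the task from the model. A minimal sketch distilled from the tests in this diff (ModelConfig and its arguments are copied from the hunks below; the resolution behavior mirrors test_auto_task and is not an additional guarantee):

    from vllm.config import ModelConfig

    # task="auto" resolves the task from the model's registered
    # architecture; OPT is generation-only, so "auto" is unambiguous here.
    config = ModelConfig(
        "facebook/opt-125m",
        task="auto",
        tokenizer="facebook/opt-125m",
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
    )
    assert config.task == "generate"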

@@ -2,6 +2,42 @@ import pytest
 
 from vllm.config import ModelConfig
 
+
+@pytest.mark.parametrize(("model_id", "expected_task"), [
+    ("facebook/opt-125m", "generate"),
+    ("intfloat/e5-mistral-7b-instruct", "embedding"),
+])
+def test_auto_task(model_id, expected_task):
+    config = ModelConfig(
+        model_id,
+        task="auto",
+        tokenizer=model_id,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype="float16",
+    )
+
+    assert config.task == expected_task
+
+
+@pytest.mark.parametrize(("model_id", "bad_task"), [
+    ("facebook/opt-125m", "embedding"),
+    ("intfloat/e5-mistral-7b-instruct", "generate"),
+])
+def test_incorrect_task(model_id, bad_task):
+    with pytest.raises(ValueError, match=r"does not support the .* task"):
+        ModelConfig(
+            model_id,
+            task=bad_task,
+            tokenizer=model_id,
+            tokenizer_mode="auto",
+            trust_remote_code=False,
+            seed=0,
+            dtype="float16",
+        )
+
+
 MODEL_IDS_EXPECTED = [
     ("Qwen/Qwen1.5-7B", 32768),
     ("mistralai/Mistral-7B-v0.1", 4096),
@@ -14,7 +50,8 @@ def test_disable_sliding_window(model_id_expected):
     model_id, expected = model_id_expected
     model_config = ModelConfig(
         model_id,
-        model_id,
+        task="auto",
+        tokenizer=model_id,
         tokenizer_mode="auto",
         trust_remote_code=False,
         seed=0,
@@ -32,7 +69,8 @@ def test_get_sliding_window():
     # when use_sliding_window is False.
     qwen2_model_config = ModelConfig(
         "Qwen/Qwen1.5-7B",
-        "Qwen/Qwen1.5-7B",
+        task="auto",
+        tokenizer="Qwen/Qwen1.5-7B",
         tokenizer_mode="auto",
         trust_remote_code=False,
         seed=0,
@@ -49,7 +87,8 @@ def test_get_sliding_window():
 
     mistral_model_config = ModelConfig(
         "mistralai/Mistral-7B-v0.1",
-        "mistralai/Mistral-7B-v0.1",
+        task="auto",
+        tokenizer="mistralai/Mistral-7B-v0.1",
         tokenizer_mode="auto",
         trust_remote_code=False,
         seed=0,
@@ -70,7 +109,8 @@ def test_rope_customization():
 
     llama_model_config = ModelConfig(
         "meta-llama/Meta-Llama-3-8B-Instruct",
-        "meta-llama/Meta-Llama-3-8B-Instruct",
+        task="auto",
+        tokenizer="meta-llama/Meta-Llama-3-8B-Instruct",
         tokenizer_mode="auto",
         trust_remote_code=False,
         dtype="float16",
@@ -82,7 +122,8 @@ def test_rope_customization():
 
     llama_model_config = ModelConfig(
         "meta-llama/Meta-Llama-3-8B-Instruct",
-        "meta-llama/Meta-Llama-3-8B-Instruct",
+        task="auto",
+        tokenizer="meta-llama/Meta-Llama-3-8B-Instruct",
         tokenizer_mode="auto",
         trust_remote_code=False,
         dtype="float16",
@@ -98,7 +139,8 @@ def test_rope_customization():
 
     longchat_model_config = ModelConfig(
         "lmsys/longchat-13b-16k",
-        "lmsys/longchat-13b-16k",
+        task="auto",
+        tokenizer="lmsys/longchat-13b-16k",
         tokenizer_mode="auto",
         trust_remote_code=False,
         dtype="float16",
@@ -112,7 +154,8 @@ def test_rope_customization():
 
     longchat_model_config = ModelConfig(
         "lmsys/longchat-13b-16k",
-        "lmsys/longchat-13b-16k",
+        task="auto",
+        tokenizer="lmsys/longchat-13b-16k",
         tokenizer_mode="auto",
         trust_remote_code=False,
         dtype="float16",
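
For the rejection path, the same arguments used by test_incorrect_task can be run outside pytest; a quick sketch (everything here is copied from the test above, and the wording of the raised message is assumed only to match "does not support the .* task"):

    from vllm.config import ModelConfig

    try:
        # An embedding-only checkpoint asked to generate: config
        # validation is expected to reject the mismatched task.
        ModelConfig(
            "intfloat/e5-mistral-7b-instruct",
            task="generate",
            tokenizer="intfloat/e5-mistral-7b-instruct",
            tokenizer_mode="auto",
            trust_remote_code=False,
            seed=0,
            dtype="float16",
        )
    except ValueError as exc:
        print(exc)  # expected to match: does not support the .* task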