[Model] Add user-configurable task for models that support both generation and embedding (#9424)
This commit is contained in:
@@ -2,6 +2,42 @@ import pytest
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model_id", "expected_task"), [
    ("facebook/opt-125m", "generate"),
    ("intfloat/e5-mistral-7b-instruct", "embedding"),
])
def test_auto_task(model_id, expected_task):
    """`task="auto"` should resolve to the single task the model supports."""
    # Everything except the positional model id is boilerplate shared by
    # both parametrized cases; gather it once for readability.
    common_kwargs = dict(
        task="auto",
        tokenizer=model_id,
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
    )

    resolved_config = ModelConfig(model_id, **common_kwargs)

    assert resolved_config.task == expected_task
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model_id", "bad_task"), [
    ("facebook/opt-125m", "embedding"),
    ("intfloat/e5-mistral-7b-instruct", "generate"),
])
def test_incorrect_task(model_id, bad_task):
    """Requesting a task the model cannot perform must raise ValueError."""
    # The error message is matched loosely: any message of the form
    # "... does not support the <task> task" is accepted.
    expected_msg = r"does not support the .* task"

    with pytest.raises(ValueError, match=expected_msg):
        ModelConfig(
            model_id,
            task=bad_task,
            tokenizer=model_id,
            tokenizer_mode="auto",
            trust_remote_code=False,
            seed=0,
            dtype="float16",
        )
|
||||
|
||||
|
||||
MODEL_IDS_EXPECTED = [
|
||||
("Qwen/Qwen1.5-7B", 32768),
|
||||
("mistralai/Mistral-7B-v0.1", 4096),
|
||||
@@ -14,7 +50,8 @@ def test_disable_sliding_window(model_id_expected):
|
||||
model_id, expected = model_id_expected
|
||||
model_config = ModelConfig(
|
||||
model_id,
|
||||
model_id,
|
||||
task="auto",
|
||||
tokenizer=model_id,
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
seed=0,
|
||||
@@ -32,7 +69,8 @@ def test_get_sliding_window():
|
||||
# when use_sliding_window is False.
|
||||
qwen2_model_config = ModelConfig(
|
||||
"Qwen/Qwen1.5-7B",
|
||||
"Qwen/Qwen1.5-7B",
|
||||
task="auto",
|
||||
tokenizer="Qwen/Qwen1.5-7B",
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
seed=0,
|
||||
@@ -49,7 +87,8 @@ def test_get_sliding_window():
|
||||
|
||||
mistral_model_config = ModelConfig(
|
||||
"mistralai/Mistral-7B-v0.1",
|
||||
"mistralai/Mistral-7B-v0.1",
|
||||
task="auto",
|
||||
tokenizer="mistralai/Mistral-7B-v0.1",
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
seed=0,
|
||||
@@ -70,7 +109,8 @@ def test_rope_customization():
|
||||
|
||||
llama_model_config = ModelConfig(
|
||||
"meta-llama/Meta-Llama-3-8B-Instruct",
|
||||
"meta-llama/Meta-Llama-3-8B-Instruct",
|
||||
task="auto",
|
||||
tokenizer="meta-llama/Meta-Llama-3-8B-Instruct",
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
dtype="float16",
|
||||
@@ -82,7 +122,8 @@ def test_rope_customization():
|
||||
|
||||
llama_model_config = ModelConfig(
|
||||
"meta-llama/Meta-Llama-3-8B-Instruct",
|
||||
"meta-llama/Meta-Llama-3-8B-Instruct",
|
||||
task="auto",
|
||||
tokenizer="meta-llama/Meta-Llama-3-8B-Instruct",
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
dtype="float16",
|
||||
@@ -98,7 +139,8 @@ def test_rope_customization():
|
||||
|
||||
longchat_model_config = ModelConfig(
|
||||
"lmsys/longchat-13b-16k",
|
||||
"lmsys/longchat-13b-16k",
|
||||
task="auto",
|
||||
tokenizer="lmsys/longchat-13b-16k",
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
dtype="float16",
|
||||
@@ -112,7 +154,8 @@ def test_rope_customization():
|
||||
|
||||
longchat_model_config = ModelConfig(
|
||||
"lmsys/longchat-13b-16k",
|
||||
"lmsys/longchat-13b-16k",
|
||||
task="auto",
|
||||
tokenizer="lmsys/longchat-13b-16k",
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
dtype="float16",
|
||||
|
||||
Reference in New Issue
Block a user