[Core] Support multiple tasks per model (#20771)

Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Nicolò Lucchesi
2025-07-13 04:40:11 +02:00
committed by GitHub
parent c1acd6d7d4
commit 020f58abcd
8 changed files with 279 additions and 148 deletions

View File

@@ -54,7 +54,7 @@ def test_get_field():
("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"),
("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "classify"),
("Qwen/Qwen2.5-Math-RM-72B", "pooling", "reward"),
("openai/whisper-small", "transcription", "transcription"),
("openai/whisper-small", "generate", "transcription"),
],
)
def test_auto_task(model_id, expected_runner_type, expected_task):
@@ -69,7 +69,11 @@ def test_auto_task(model_id, expected_runner_type, expected_task):
)
assert config.runner_type == expected_runner_type
assert config.task == expected_task
if config.runner_type == "pooling":
assert config.task == expected_task
else:
assert expected_task in config.supported_tasks
@pytest.mark.parametrize(
@@ -98,11 +102,50 @@ def test_score_task(model_id, expected_runner_type, expected_task):
assert config.task == expected_task
@pytest.mark.parametrize(("model_id", "expected_runner_type", "expected_task"),
[
("Qwen/Qwen2.5-1.5B-Instruct", "draft", "auto"),
])
def test_draft_task(model_id, expected_runner_type, expected_task):
config = ModelConfig(
model_id,
runner="draft",
tokenizer=model_id,
seed=0,
dtype="float16",
)
assert config.runner_type == expected_runner_type
assert config.task == expected_task
@pytest.mark.parametrize(
("model_id", "expected_runner_type", "expected_task"),
[
("openai/whisper-small", "generate", "transcription"),
],
)
def test_transcription_task(model_id, expected_runner_type, expected_task):
config = ModelConfig(
model_id,
task="transcription",
tokenizer=model_id,
tokenizer_mode="auto",
trust_remote_code=False,
seed=0,
dtype="float16",
)
assert config.runner_type == expected_runner_type
assert config.task == expected_task
@pytest.mark.parametrize(("model_id", "bad_task"), [
("Qwen/Qwen2.5-Math-RM-72B", "generate"),
("Qwen/Qwen3-0.6B", "transcription"),
])
def test_incorrect_task(model_id, bad_task):
with pytest.raises(ValueError, match=r"does not support the .* task"):
with pytest.raises(ValueError, match=r"does not support task=.*"):
ModelConfig(
model_id,
task=bad_task,