[Deprecation][2/N] Replace --task with --runner and --convert (#21470)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -74,115 +74,116 @@ def test_update_config():
|
||||
new_config3 = update_config(config3, {"a": "new_value"})
|
||||
|
||||
|
||||
# Can remove once --task option is fully deprecated
|
||||
@pytest.mark.parametrize(
|
||||
("model_id", "expected_runner_type", "expected_task"),
|
||||
("model_id", "expected_runner_type", "expected_convert_type",
|
||||
"expected_task"),
|
||||
[
|
||||
("distilbert/distilgpt2", "generate", "generate"),
|
||||
("intfloat/multilingual-e5-small", "pooling", "embed"),
|
||||
("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"),
|
||||
("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "classify"),
|
||||
("Qwen/Qwen2.5-Math-RM-72B", "pooling", "reward"),
|
||||
("openai/whisper-small", "generate", "transcription"),
|
||||
("distilbert/distilgpt2", "generate", "none", "generate"),
|
||||
("intfloat/multilingual-e5-small", "pooling", "none", "embed"),
|
||||
("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify", "classify"),
|
||||
("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "none",
|
||||
"classify"),
|
||||
("Qwen/Qwen2.5-Math-RM-72B", "pooling", "none", "reward"),
|
||||
("openai/whisper-small", "generate", "none", "transcription"),
|
||||
],
|
||||
)
|
||||
def test_auto_task(model_id, expected_runner_type, expected_task):
|
||||
config = ModelConfig(
|
||||
model_id,
|
||||
task="auto",
|
||||
tokenizer=model_id,
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
)
|
||||
def test_auto_task(model_id, expected_runner_type, expected_convert_type,
|
||||
expected_task):
|
||||
config = ModelConfig(model_id, task="auto")
|
||||
|
||||
assert config.runner_type == expected_runner_type
|
||||
assert config.convert_type == expected_convert_type
|
||||
assert expected_task in config.supported_tasks
|
||||
|
||||
if config.runner_type == "pooling":
|
||||
assert config.task == expected_task
|
||||
else:
|
||||
assert expected_task in config.supported_tasks
|
||||
|
||||
# Can remove once --task option is fully deprecated
|
||||
@pytest.mark.parametrize(
|
||||
("model_id", "expected_runner_type", "expected_convert_type",
|
||||
"expected_task"),
|
||||
[
|
||||
("distilbert/distilgpt2", "pooling", "embed", "embed"),
|
||||
("intfloat/multilingual-e5-small", "pooling", "embed", "embed"),
|
||||
("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify", "classify"),
|
||||
("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "classify",
|
||||
"classify"),
|
||||
("Qwen/Qwen2.5-Math-RM-72B", "pooling", "embed", "embed"),
|
||||
("openai/whisper-small", "pooling", "embed", "embed"),
|
||||
],
|
||||
)
|
||||
def test_score_task(model_id, expected_runner_type, expected_convert_type,
|
||||
expected_task):
|
||||
config = ModelConfig(model_id, task="score")
|
||||
|
||||
assert config.runner_type == expected_runner_type
|
||||
assert config.convert_type == expected_convert_type
|
||||
assert expected_task in config.supported_tasks
|
||||
|
||||
|
||||
# Can remove once --task option is fully deprecated
|
||||
@pytest.mark.parametrize(
|
||||
("model_id", "expected_runner_type", "expected_convert_type",
|
||||
"expected_task"),
|
||||
[
|
||||
("openai/whisper-small", "generate", "none", "transcription"),
|
||||
],
|
||||
)
|
||||
def test_transcription_task(model_id, expected_runner_type,
|
||||
expected_convert_type, expected_task):
|
||||
config = ModelConfig(model_id, task="transcription")
|
||||
|
||||
assert config.runner_type == expected_runner_type
|
||||
assert config.convert_type == expected_convert_type
|
||||
assert expected_task in config.supported_tasks
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("model_id", "expected_runner_type", "expected_task"),
|
||||
("model_id", "expected_runner_type", "expected_convert_type"),
|
||||
[
|
||||
("distilbert/distilgpt2", "generate", "none"),
|
||||
("intfloat/multilingual-e5-small", "pooling", "none"),
|
||||
("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"),
|
||||
("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "none"),
|
||||
("Qwen/Qwen2.5-Math-RM-72B", "pooling", "none"),
|
||||
("openai/whisper-small", "generate", "none"),
|
||||
],
|
||||
)
|
||||
def test_auto_runner(model_id, expected_runner_type, expected_convert_type):
|
||||
config = ModelConfig(model_id, runner="auto")
|
||||
|
||||
assert config.runner_type == expected_runner_type
|
||||
assert config.convert_type == expected_convert_type
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("model_id", "expected_runner_type", "expected_convert_type"),
|
||||
[
|
||||
("distilbert/distilgpt2", "pooling", "embed"),
|
||||
("intfloat/multilingual-e5-small", "pooling", "embed"),
|
||||
("intfloat/multilingual-e5-small", "pooling", "none"),
|
||||
("jason9693/Qwen2.5-1.5B-apeach", "pooling", "classify"),
|
||||
("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "classify"),
|
||||
("Qwen/Qwen2.5-Math-RM-72B", "pooling", "embed"),
|
||||
("cross-encoder/ms-marco-MiniLM-L-6-v2", "pooling", "none"),
|
||||
("Qwen/Qwen2.5-Math-RM-72B", "pooling", "none"),
|
||||
("openai/whisper-small", "pooling", "embed"),
|
||||
],
|
||||
)
|
||||
def test_score_task(model_id, expected_runner_type, expected_task):
|
||||
config = ModelConfig(
|
||||
model_id,
|
||||
task="score",
|
||||
tokenizer=model_id,
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
)
|
||||
def test_pooling_runner(model_id, expected_runner_type, expected_convert_type):
|
||||
config = ModelConfig(model_id, runner="pooling")
|
||||
|
||||
assert config.runner_type == expected_runner_type
|
||||
assert config.task == expected_task
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model_id", "expected_runner_type", "expected_task"),
|
||||
[
|
||||
("Qwen/Qwen2.5-1.5B-Instruct", "draft", "auto"),
|
||||
])
|
||||
def test_draft_task(model_id, expected_runner_type, expected_task):
|
||||
config = ModelConfig(
|
||||
model_id,
|
||||
runner="draft",
|
||||
tokenizer=model_id,
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
)
|
||||
|
||||
assert config.runner_type == expected_runner_type
|
||||
assert config.task == expected_task
|
||||
assert config.convert_type == expected_convert_type
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("model_id", "expected_runner_type", "expected_task"),
|
||||
("model_id", "expected_runner_type", "expected_convert_type"),
|
||||
[
|
||||
("openai/whisper-small", "generate", "transcription"),
|
||||
("Qwen/Qwen2.5-1.5B-Instruct", "draft", "none"),
|
||||
],
|
||||
)
|
||||
def test_transcription_task(model_id, expected_runner_type, expected_task):
|
||||
config = ModelConfig(
|
||||
model_id,
|
||||
task="transcription",
|
||||
tokenizer=model_id,
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
)
|
||||
def test_draft_runner(model_id, expected_runner_type, expected_convert_type):
|
||||
config = ModelConfig(model_id, runner="draft")
|
||||
|
||||
assert config.runner_type == expected_runner_type
|
||||
assert config.task == expected_task
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("model_id", "bad_task"), [
|
||||
("Qwen/Qwen2.5-Math-RM-72B", "generate"),
|
||||
("Qwen/Qwen3-0.6B", "transcription"),
|
||||
])
|
||||
def test_incorrect_task(model_id, bad_task):
|
||||
with pytest.raises(ValueError, match=r"does not support task=.*"):
|
||||
ModelConfig(
|
||||
model_id,
|
||||
task=bad_task,
|
||||
tokenizer=model_id,
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
)
|
||||
assert config.convert_type == expected_convert_type
|
||||
|
||||
|
||||
MODEL_IDS_EXPECTED = [
|
||||
@@ -195,17 +196,7 @@ MODEL_IDS_EXPECTED = [
|
||||
@pytest.mark.parametrize("model_id_expected", MODEL_IDS_EXPECTED)
|
||||
def test_disable_sliding_window(model_id_expected):
|
||||
model_id, expected = model_id_expected
|
||||
model_config = ModelConfig(
|
||||
model_id,
|
||||
task="auto",
|
||||
tokenizer=model_id,
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
revision=None,
|
||||
disable_sliding_window=True,
|
||||
)
|
||||
model_config = ModelConfig(model_id, disable_sliding_window=True)
|
||||
assert model_config.max_model_len == expected
|
||||
|
||||
|
||||
@@ -214,16 +205,7 @@ def test_get_sliding_window():
|
||||
# Test that the sliding window is correctly computed.
|
||||
# For Qwen1.5/Qwen2, get_sliding_window() should be None
|
||||
# when use_sliding_window is False.
|
||||
qwen2_model_config = ModelConfig(
|
||||
"Qwen/Qwen1.5-7B",
|
||||
task="auto",
|
||||
tokenizer="Qwen/Qwen1.5-7B",
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
revision=None,
|
||||
)
|
||||
qwen2_model_config = ModelConfig("Qwen/Qwen1.5-7B")
|
||||
|
||||
qwen2_model_config.hf_config.use_sliding_window = False
|
||||
qwen2_model_config.hf_config.sliding_window = TEST_SLIDING_WINDOW
|
||||
@@ -232,16 +214,7 @@ def test_get_sliding_window():
|
||||
qwen2_model_config.hf_config.use_sliding_window = True
|
||||
assert qwen2_model_config.get_sliding_window() == TEST_SLIDING_WINDOW
|
||||
|
||||
mistral_model_config = ModelConfig(
|
||||
"mistralai/Mistral-7B-v0.1",
|
||||
task="auto",
|
||||
tokenizer="mistralai/Mistral-7B-v0.1",
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
revision=None,
|
||||
)
|
||||
mistral_model_config = ModelConfig("mistralai/Mistral-7B-v0.1")
|
||||
mistral_model_config.hf_config.sliding_window = None
|
||||
assert mistral_model_config.get_sliding_window() is None
|
||||
|
||||
@@ -253,16 +226,7 @@ def test_get_sliding_window():
|
||||
reason="Xformers backend is not supported on ROCm.")
|
||||
def test_get_pooling_config():
|
||||
model_id = "sentence-transformers/all-MiniLM-L12-v2"
|
||||
model_config = ModelConfig(
|
||||
model_id,
|
||||
task="auto",
|
||||
tokenizer=model_id,
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
revision=None,
|
||||
)
|
||||
model_config = ModelConfig(model_id)
|
||||
|
||||
pooling_config = model_config._init_pooler_config()
|
||||
assert pooling_config is not None
|
||||
@@ -275,14 +239,7 @@ def test_get_pooling_config():
|
||||
reason="Xformers backend is not supported on ROCm.")
|
||||
def test_get_pooling_config_from_args():
|
||||
model_id = "sentence-transformers/all-MiniLM-L12-v2"
|
||||
model_config = ModelConfig(model_id,
|
||||
task="auto",
|
||||
tokenizer=model_id,
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
revision=None)
|
||||
model_config = ModelConfig(model_id)
|
||||
|
||||
override_pooler_config = PoolerConfig(pooling_type='CLS', normalize=True)
|
||||
model_config.override_pooler_config = override_pooler_config
|
||||
@@ -295,16 +252,8 @@ def test_get_pooling_config_from_args():
|
||||
@pytest.mark.skipif(current_platform.is_rocm(),
|
||||
reason="Xformers backend is not supported on ROCm.")
|
||||
def test_get_bert_tokenization_sentence_transformer_config():
|
||||
bge_model_config = ModelConfig(
|
||||
model="BAAI/bge-base-en-v1.5",
|
||||
task="auto",
|
||||
tokenizer="BAAI/bge-base-en-v1.5",
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
revision=None,
|
||||
)
|
||||
model_id = "BAAI/bge-base-en-v1.5"
|
||||
bge_model_config = ModelConfig(model_id)
|
||||
|
||||
bert_bge_model_config = bge_model_config._get_encoder_config()
|
||||
|
||||
@@ -317,27 +266,13 @@ def test_rope_customization():
|
||||
TEST_ROPE_THETA = 16_000_000.0
|
||||
LONGCHAT_ROPE_SCALING = {"rope_type": "linear", "factor": 8.0}
|
||||
|
||||
llama_model_config = ModelConfig(
|
||||
"meta-llama/Meta-Llama-3-8B-Instruct",
|
||||
task="auto",
|
||||
tokenizer="meta-llama/Meta-Llama-3-8B-Instruct",
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
dtype="float16",
|
||||
seed=0,
|
||||
)
|
||||
llama_model_config = ModelConfig("meta-llama/Meta-Llama-3-8B-Instruct")
|
||||
assert getattr(llama_model_config.hf_config, "rope_scaling", None) is None
|
||||
assert getattr(llama_model_config.hf_config, "rope_theta", None) == 500_000
|
||||
assert llama_model_config.max_model_len == 8192
|
||||
|
||||
llama_model_config = ModelConfig(
|
||||
"meta-llama/Meta-Llama-3-8B-Instruct",
|
||||
task="auto",
|
||||
tokenizer="meta-llama/Meta-Llama-3-8B-Instruct",
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
dtype="float16",
|
||||
seed=0,
|
||||
hf_overrides={
|
||||
"rope_scaling": TEST_ROPE_SCALING,
|
||||
"rope_theta": TEST_ROPE_THETA,
|
||||
@@ -349,15 +284,7 @@ def test_rope_customization():
|
||||
None) == TEST_ROPE_THETA
|
||||
assert llama_model_config.max_model_len == 16384
|
||||
|
||||
longchat_model_config = ModelConfig(
|
||||
"lmsys/longchat-13b-16k",
|
||||
task="auto",
|
||||
tokenizer="lmsys/longchat-13b-16k",
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
dtype="float16",
|
||||
seed=0,
|
||||
)
|
||||
longchat_model_config = ModelConfig("lmsys/longchat-13b-16k")
|
||||
# Check if LONGCHAT_ROPE_SCALING entries are in longchat_model_config
|
||||
assert all(
|
||||
longchat_model_config.hf_config.rope_scaling.get(key) == value
|
||||
@@ -366,12 +293,6 @@ def test_rope_customization():
|
||||
|
||||
longchat_model_config = ModelConfig(
|
||||
"lmsys/longchat-13b-16k",
|
||||
task="auto",
|
||||
tokenizer="lmsys/longchat-13b-16k",
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
dtype="float16",
|
||||
seed=0,
|
||||
hf_overrides={
|
||||
"rope_scaling": TEST_ROPE_SCALING,
|
||||
},
|
||||
@@ -390,15 +311,7 @@ def test_rope_customization():
|
||||
("meta-llama/Llama-3.2-11B-Vision", True),
|
||||
])
|
||||
def test_is_encoder_decoder(model_id, is_encoder_decoder):
|
||||
config = ModelConfig(
|
||||
model_id,
|
||||
task="auto",
|
||||
tokenizer=model_id,
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
dtype="float16",
|
||||
seed=0,
|
||||
)
|
||||
config = ModelConfig(model_id)
|
||||
|
||||
assert config.is_encoder_decoder == is_encoder_decoder
|
||||
|
||||
@@ -408,15 +321,7 @@ def test_is_encoder_decoder(model_id, is_encoder_decoder):
|
||||
("Qwen/Qwen2-VL-2B-Instruct", True),
|
||||
])
|
||||
def test_uses_mrope(model_id, uses_mrope):
|
||||
config = ModelConfig(
|
||||
model_id,
|
||||
task="auto",
|
||||
tokenizer=model_id,
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
dtype="float16",
|
||||
seed=0,
|
||||
)
|
||||
config = ModelConfig(model_id)
|
||||
|
||||
assert config.uses_mrope == uses_mrope
|
||||
|
||||
@@ -426,26 +331,12 @@ def test_generation_config_loading():
|
||||
|
||||
# When set generation_config to "vllm", the default generation config
|
||||
# will not be loaded.
|
||||
model_config = ModelConfig(model_id,
|
||||
task="auto",
|
||||
tokenizer=model_id,
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
generation_config="vllm")
|
||||
model_config = ModelConfig(model_id, generation_config="vllm")
|
||||
assert model_config.get_diff_sampling_param() == {}
|
||||
|
||||
# When set generation_config to "auto", the default generation config
|
||||
# should be loaded.
|
||||
model_config = ModelConfig(model_id,
|
||||
task="auto",
|
||||
tokenizer=model_id,
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
generation_config="auto")
|
||||
model_config = ModelConfig(model_id, generation_config="auto")
|
||||
|
||||
correct_generation_config = {
|
||||
"repetition_penalty": 1.1,
|
||||
@@ -461,12 +352,6 @@ def test_generation_config_loading():
|
||||
|
||||
model_config = ModelConfig(
|
||||
model_id,
|
||||
task="auto",
|
||||
tokenizer=model_id,
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
generation_config="auto",
|
||||
override_generation_config=override_generation_config)
|
||||
|
||||
@@ -479,12 +364,6 @@ def test_generation_config_loading():
|
||||
# is set, the override_generation_config should be used directly.
|
||||
model_config = ModelConfig(
|
||||
model_id,
|
||||
task="auto",
|
||||
tokenizer=model_id,
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
generation_config="vllm",
|
||||
override_generation_config=override_generation_config)
|
||||
|
||||
@@ -515,16 +394,7 @@ def test_load_config_pt_load_map_location(pt_load_map_location):
|
||||
def test_get_and_verify_max_len(model_id, max_model_len, expected_max_len,
|
||||
should_raise):
|
||||
"""Test get_and_verify_max_len with different configurations."""
|
||||
model_config = ModelConfig(
|
||||
model_id,
|
||||
task="auto",
|
||||
tokenizer=model_id,
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=False,
|
||||
seed=0,
|
||||
dtype="float16",
|
||||
revision=None,
|
||||
)
|
||||
model_config = ModelConfig(model_id)
|
||||
|
||||
if should_raise:
|
||||
with pytest.raises(ValueError):
|
||||
|
||||
Reference in New Issue
Block a user