Enforce that model is the first positional arg when --served-model-name is used (#34973)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -20,10 +20,22 @@ CHATML_JINJA_PATH = VLLM_PATH / "examples/template_chatml.jinja"
|
||||
assert CHATML_JINJA_PATH.exists()
|
||||
|
||||
|
||||
def _build_vllm_parsers():
|
||||
vllm_parser = FlexibleArgumentParser()
|
||||
subparsers = vllm_parser.add_subparsers()
|
||||
serve_parser = subparsers.add_parser("serve")
|
||||
make_arg_parser(serve_parser)
|
||||
return {"vllm": vllm_parser, "vllm serve": serve_parser}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vllm_parser():
|
||||
return _build_vllm_parsers()["vllm"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def serve_parser():
|
||||
parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.")
|
||||
return make_arg_parser(parser)
|
||||
return _build_vllm_parsers()["vllm serve"]
|
||||
|
||||
|
||||
### Test config parsing
|
||||
@@ -241,3 +253,41 @@ def test_default_chat_template_kwargs_invalid_json(serve_parser):
|
||||
serve_parser.parse_args(
|
||||
args=["--default-chat-template-kwargs", "not valid json"]
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"args, raises",
|
||||
[
|
||||
(["user/model"], None),
|
||||
(["user/model", "--served-model-name", "model"], None),
|
||||
(["--served-model-name", "model", "user/model"], ValueError),
|
||||
(["--served-model-name", "model", "--config", "config.yaml"], None),
|
||||
(["--served-model-name", "model", "--config", "config.yaml"], ValueError),
|
||||
],
|
||||
ids=[
|
||||
"model_tag_only",
|
||||
"model_tag_with_served_model_name",
|
||||
"served_model_name_before_model_tag",
|
||||
"served_model_name_with_model_in_config",
|
||||
"served_model_name_with_no_model_in_config",
|
||||
],
|
||||
)
|
||||
def test_served_model_name_parsing(tmp_path, vllm_parser, args, raises):
|
||||
"""Ensure that users don't misuse --served-model-name and end up with the default
|
||||
model tag instead of the one they intended to serve."""
|
||||
# Call the serve subparser
|
||||
args.insert(0, "serve")
|
||||
# Create a dummy config file if the test case includes it
|
||||
if "config.yaml" in args:
|
||||
# Create a dummy config file if the test case includes it
|
||||
config_path = tmp_path / "config.yaml"
|
||||
config_path.write_text("model: user/model" if raises is None else "port: 8000")
|
||||
args[args.index("config.yaml")] = config_path.as_posix()
|
||||
# Do the parsing and check for expected exceptions or values
|
||||
if raises is None:
|
||||
parsed_args = vllm_parser.parse_args(args=args)
|
||||
expected = "user/model"
|
||||
assert parsed_args.model_tag == expected or parsed_args.model == expected
|
||||
else:
|
||||
with pytest.raises(raises):
|
||||
vllm_parser.parse_args(args=args)
|
||||
|
||||
Reference in New Issue
Block a user