diff --git a/tests/entrypoints/openai/test_cli_args.py b/tests/entrypoints/openai/test_cli_args.py
index dd5d62990..ccf145a0c 100644
--- a/tests/entrypoints/openai/test_cli_args.py
+++ b/tests/entrypoints/openai/test_cli_args.py
@@ -20,10 +20,22 @@ CHATML_JINJA_PATH = VLLM_PATH / "examples/template_chatml.jinja"
 assert CHATML_JINJA_PATH.exists()
 
 
+def _build_vllm_parsers():
+    vllm_parser = FlexibleArgumentParser()
+    subparsers = vllm_parser.add_subparsers()
+    serve_parser = subparsers.add_parser("serve")
+    make_arg_parser(serve_parser)
+    return {"vllm": vllm_parser, "vllm serve": serve_parser}
+
+
+@pytest.fixture
+def vllm_parser():
+    return _build_vllm_parsers()["vllm"]
+
+
 @pytest.fixture
 def serve_parser():
-    parser = FlexibleArgumentParser(description="vLLM's remote OpenAI server.")
-    return make_arg_parser(parser)
+    return _build_vllm_parsers()["vllm serve"]
 
 
 ### Test config parsing
@@ -241,3 +253,40 @@ def test_default_chat_template_kwargs_invalid_json(serve_parser):
         serve_parser.parse_args(
             args=["--default-chat-template-kwargs", "not valid json"]
         )
+
+
+@pytest.mark.parametrize(
+    "args, raises",
+    [
+        (["user/model"], None),
+        (["user/model", "--served-model-name", "model"], None),
+        (["--served-model-name", "model", "user/model"], ValueError),
+        (["--served-model-name", "model", "--config", "config.yaml"], None),
+        (["--served-model-name", "model", "--config", "config.yaml"], ValueError),
+    ],
+    ids=[
+        "model_tag_only",
+        "model_tag_with_served_model_name",
+        "served_model_name_before_model_tag",
+        "served_model_name_with_model_in_config",
+        "served_model_name_with_no_model_in_config",
+    ],
+)
+def test_served_model_name_parsing(tmp_path, vllm_parser, args, raises):
+    """Ensure that users don't misuse --served-model-name and end up with the default
+    model tag instead of the one they intended to serve."""
+    # Call the serve subparser
+    args.insert(0, "serve")
+    # Create a dummy config file if the test case includes it
+    if "config.yaml" in args:
+        config_path = tmp_path / "config.yaml"
+        config_path.write_text("model: user/model" if raises is None else "port: 8000")
+        args[args.index("config.yaml")] = config_path.as_posix()
+    # Do the parsing and check for expected exceptions or values
+    if raises is None:
+        parsed_args = vllm_parser.parse_args(args=args)
+        expected = "user/model"
+        assert parsed_args.model_tag == expected or parsed_args.model == expected
+    else:
+        with pytest.raises(raises):
+            vllm_parser.parse_args(args=args)
diff --git a/vllm/utils/argparse_utils.py b/vllm/utils/argparse_utils.py
index d88f2fa6f..e4482d4fb 100644
--- a/vllm/utils/argparse_utils.py
+++ b/vllm/utils/argparse_utils.py
@@ -184,13 +184,11 @@ class FlexibleArgumentParser(ArgumentParser):
         if args is None:
             args = sys.argv[1:]
 
-        # Check for --model in command line arguments first
         if args and args[0] == "serve":
+            # Check for --model in command line arguments first
             try:
                 model_idx = next(
-                    i
-                    for i, arg in enumerate(args)
-                    if arg == "--model" or arg.startswith("--model=")
+                    i for i, arg in enumerate(args) if re.match(r"^--model(=.+|$)", arg)
                 )
                 logger.warning(
                     "With `vllm serve`, you should provide the model as a "
@@ -219,6 +217,19 @@ class FlexibleArgumentParser(ArgumentParser):
                 ]
             except StopIteration:
                 pass
+            # Check for --served-model-name without a positional model argument
+            if (
+                len(args) > 1
+                and args[1].startswith("-")
+                and not any(re.match(r"^--config(=.+|$)", arg) for arg in args)
+                and any(
re.match(r"^--served[-_]model[-_]name(=.+|$)", arg) for arg in args + ) + ): + raise ValueError( + "`model` should be provided as the first positional argument when " + "using `vllm serve`. i.e. `vllm serve -- `." + ) if "--config" in args: args = self._pull_args_from_config(args)