[Feature] specify model in config.yaml (#15798)

Signed-off-by: weizeng <weizeng@roblox.com>
This commit is contained in:
Wei Zeng
2025-04-01 01:20:06 -07:00
committed by GitHub
parent 8af5a5c4e5
commit 30d6a015e0
7 changed files with 109 additions and 32 deletions

View File

@@ -1241,6 +1241,16 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
if args is None:
args = sys.argv[1:]
# Check for --model in command line arguments first
if args and args[0] == "serve":
model_in_cli_args = any(arg == '--model' for arg in args)
if model_in_cli_args:
raise ValueError(
"With `vllm serve`, you should provide the model as a "
"positional argument or in a config file instead of via "
"the `--model` option.")
if '--config' in args:
args = self._pull_args_from_config(args)
@@ -1324,19 +1334,29 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
config_args = self._load_config_file(file_path)
# 0th index is for {serve,chat,complete}
# followed by model_tag (only for serve)
# optionally followed by model_tag (only for serve)
# followed by config args
# followed by rest of cli args.
# maintaining this order will enforce the precedence
# of cli > config > defaults
if args[0] == "serve":
if index == 1:
model_in_cli = len(args) > 1 and not args[1].startswith('-')
model_in_config = any(arg == '--model' for arg in config_args)
if not model_in_cli and not model_in_config:
raise ValueError(
"No model_tag specified! Please check your command-line"
" arguments.")
args = [args[0]] + [
args[1]
] + config_args + args[2:index] + args[index + 2:]
"No model specified! Please specify model either "
"as a positional argument or in a config file.")
if model_in_cli:
# Model specified as positional arg, keep CLI version
args = [args[0]] + [
args[1]
] + config_args + args[2:index] + args[index + 2:]
else:
# No model in CLI, use config if available
args = [args[0]
] + config_args + args[1:index] + args[index + 2:]
else:
args = [args[0]] + config_args + args[1:index] + args[index + 2:]
@@ -1354,9 +1374,7 @@ class FlexibleArgumentParser(argparse.ArgumentParser):
'--port': '12323',
'--tensor-parallel-size': '4'
]
"""
extension: str = file_path.split('.')[-1]
if extension not in ('yaml', 'yml'):
raise ValueError(