diff --git a/tests/utils_/test_argparse_utils.py b/tests/utils_/test_argparse_utils.py index fbc278404..53639d02b 100644 --- a/tests/utils_/test_argparse_utils.py +++ b/tests/utils_/test_argparse_utils.py @@ -379,6 +379,50 @@ def test_load_config_file(tmp_path): os.remove(str(config_file_path)) +def test_load_config_file_nested(tmp_path): + """Test that nested dicts in YAML config are converted to JSON strings.""" + config_data = { + "port": 8000, + "compilation-config": { + "pass_config": {"fuse_allreduce_rms": True}, + }, + } + config_file_path = tmp_path / "nested_config.yaml" + with open(config_file_path, "w") as f: + yaml.dump(config_data, f) + + parser = FlexibleArgumentParser() + processed_args = parser.load_config_file(str(config_file_path)) + + assert processed_args[processed_args.index("--port") + 1] == "8000" + cc_value = json.loads( + processed_args[processed_args.index("--compilation-config") + 1] + ) + assert cc_value == {"pass_config": {"fuse_allreduce_rms": True}} + + +def test_nested_config_end_to_end(tmp_path): + """Test end-to-end parsing of nested configs in YAML files.""" + config_data = { + "compilation-config": { + "mode": 3, + "pass_config": {"fuse_allreduce_rms": True}, + }, + } + config_file_path = tmp_path / "nested_config.yaml" + with open(config_file_path, "w") as f: + yaml.dump(config_data, f) + + parser = FlexibleArgumentParser() + parser.add_argument("-cc", "--compilation-config", type=json.loads) + args = parser.parse_args(["--config", str(config_file_path)]) + + assert args.compilation_config == { + "mode": 3, + "pass_config": {"fuse_allreduce_rms": True}, + } + + def test_compilation_mode_string_values(parser): """Test that -cc.mode accepts both integer and string mode values.""" args = parser.parse_args(["-cc.mode", "0"]) diff --git a/vllm/utils/argparse_utils.py b/vllm/utils/argparse_utils.py index 9c2cec876..d88f2fa6f 100644 --- a/vllm/utils/argparse_utils.py +++ b/vllm/utils/argparse_utils.py @@ -444,16 +444,30 @@ class FlexibleArgumentParser(ArgumentParser): def load_config_file(self, file_path: str) -> list[str]: """Loads a yaml file and returns the key value pairs as a - flattened list with argparse like pattern + flattened list with argparse like pattern. + + Supports both flat configs and nested YAML structures. + + Flat config example: ```yaml port: 12323 tensor-parallel-size: 4 ``` returns: - processed_args: list[str] = [ - '--port': '12323', - '--tensor-parallel-size': '4' - ] + ['--port', '12323', '--tensor-parallel-size', '4'] + + Nested config example: + ```yaml + compilation-config: + pass_config: + fuse_allreduce_rms: true + speculative-config: + model: "nvidia/gpt-oss-120b-Eagle3-v2" + num_speculative_tokens: 3 + ``` + returns: + ['--compilation-config', '{"pass_config": {"fuse_allreduce_rms": true}}', + '--speculative-config', '{"model": "nvidia/gpt-oss-120b-Eagle3-v2", ...}'] """ extension: str = file_path.split(".")[-1] if extension not in ("yaml", "yml"): @@ -461,10 +475,10 @@ class FlexibleArgumentParser(ArgumentParser): f"Config file must be of a yaml/yml type. {extension} supplied" ) - # only expecting a flat dictionary of atomic types + # Supports both flat configs and nested dicts processed_args: list[str] = [] - config: dict[str, int | str] = {} + config: dict[str, Any] = {} try: with open(file_path) as config_file: config = yaml.safe_load(config_file) @@ -484,6 +498,11 @@ class FlexibleArgumentParser(ArgumentParser): processed_args.append("--" + key) for item in value: processed_args.append(str(item)) + elif isinstance(value, dict): + # Convert nested dicts to JSON strings so they can be parsed + # by the existing JSON argument parsing machinery. + processed_args.append("--" + key) + processed_args.append(json.dumps(value)) else: processed_args.append("--" + key) processed_args.append(str(value))