[UX] Enable nested configs in config yaml files (#33193)

This commit is contained in:
Michael Goin
2026-01-28 16:54:25 -05:00
committed by GitHub
parent ab597c869a
commit ca1969186d
2 changed files with 70 additions and 7 deletions

View File

@@ -379,6 +379,50 @@ def test_load_config_file(tmp_path):
os.remove(str(config_file_path))
def test_load_config_file_nested(tmp_path):
"""Test that nested dicts in YAML config are converted to JSON strings."""
config_data = {
"port": 8000,
"compilation-config": {
"pass_config": {"fuse_allreduce_rms": True},
},
}
config_file_path = tmp_path / "nested_config.yaml"
with open(config_file_path, "w") as f:
yaml.dump(config_data, f)
parser = FlexibleArgumentParser()
processed_args = parser.load_config_file(str(config_file_path))
assert processed_args[processed_args.index("--port") + 1] == "8000"
cc_value = json.loads(
processed_args[processed_args.index("--compilation-config") + 1]
)
assert cc_value == {"pass_config": {"fuse_allreduce_rms": True}}
def test_nested_config_end_to_end(tmp_path):
"""Test end-to-end parsing of nested configs in YAML files."""
config_data = {
"compilation-config": {
"mode": 3,
"pass_config": {"fuse_allreduce_rms": True},
},
}
config_file_path = tmp_path / "nested_config.yaml"
with open(config_file_path, "w") as f:
yaml.dump(config_data, f)
parser = FlexibleArgumentParser()
parser.add_argument("-cc", "--compilation-config", type=json.loads)
args = parser.parse_args(["--config", str(config_file_path)])
assert args.compilation_config == {
"mode": 3,
"pass_config": {"fuse_allreduce_rms": True},
}
def test_compilation_mode_string_values(parser):
"""Test that -cc.mode accepts both integer and string mode values."""
args = parser.parse_args(["-cc.mode", "0"])

View File

@@ -444,16 +444,30 @@ class FlexibleArgumentParser(ArgumentParser):
def load_config_file(self, file_path: str) -> list[str]:
"""Loads a yaml file and returns the key value pairs as a
flattened list with argparse like pattern
flattened list with argparse like pattern.
Supports both flat configs and nested YAML structures.
Flat config example:
```yaml
port: 12323
tensor-parallel-size: 4
```
returns:
processed_args: list[str] = [
'--port': '12323',
'--tensor-parallel-size': '4'
]
['--port', '12323', '--tensor-parallel-size', '4']
Nested config example:
```yaml
compilation-config:
pass_config:
fuse_allreduce_rms: true
speculative-config:
model: "nvidia/gpt-oss-120b-Eagle3-v2"
num_speculative_tokens: 3
```
returns:
['--compilation-config', '{"pass_config": {"fuse_allreduce_rms": true}}',
'--speculative-config', '{"model": "nvidia/gpt-oss-120b-Eagle3-v2", ...}']
"""
extension: str = file_path.split(".")[-1]
if extension not in ("yaml", "yml"):
@@ -461,10 +475,10 @@ class FlexibleArgumentParser(ArgumentParser):
f"Config file must be of a yaml/yml type. {extension} supplied"
)
# only expecting a flat dictionary of atomic types
# Supports both flat configs and nested dicts
processed_args: list[str] = []
config: dict[str, int | str] = {}
config: dict[str, Any] = {}
try:
with open(file_path) as config_file:
config = yaml.safe_load(config_file)
@@ -484,6 +498,11 @@ class FlexibleArgumentParser(ArgumentParser):
processed_args.append("--" + key)
for item in value:
processed_args.append(str(item))
elif isinstance(value, dict):
# Convert nested dicts to JSON strings so they can be parsed
# by the existing JSON argument parsing machinery.
processed_args.append("--" + key)
processed_args.append(json.dumps(value))
else:
processed_args.append("--" + key)
processed_args.append(str(value))