[UX] Enable nested configs in config yaml files (#33193)
This commit is contained in:
@@ -379,6 +379,50 @@ def test_load_config_file(tmp_path):
|
||||
os.remove(str(config_file_path))
|
||||
|
||||
|
||||
def test_load_config_file_nested(tmp_path):
|
||||
"""Test that nested dicts in YAML config are converted to JSON strings."""
|
||||
config_data = {
|
||||
"port": 8000,
|
||||
"compilation-config": {
|
||||
"pass_config": {"fuse_allreduce_rms": True},
|
||||
},
|
||||
}
|
||||
config_file_path = tmp_path / "nested_config.yaml"
|
||||
with open(config_file_path, "w") as f:
|
||||
yaml.dump(config_data, f)
|
||||
|
||||
parser = FlexibleArgumentParser()
|
||||
processed_args = parser.load_config_file(str(config_file_path))
|
||||
|
||||
assert processed_args[processed_args.index("--port") + 1] == "8000"
|
||||
cc_value = json.loads(
|
||||
processed_args[processed_args.index("--compilation-config") + 1]
|
||||
)
|
||||
assert cc_value == {"pass_config": {"fuse_allreduce_rms": True}}
|
||||
|
||||
|
||||
def test_nested_config_end_to_end(tmp_path):
|
||||
"""Test end-to-end parsing of nested configs in YAML files."""
|
||||
config_data = {
|
||||
"compilation-config": {
|
||||
"mode": 3,
|
||||
"pass_config": {"fuse_allreduce_rms": True},
|
||||
},
|
||||
}
|
||||
config_file_path = tmp_path / "nested_config.yaml"
|
||||
with open(config_file_path, "w") as f:
|
||||
yaml.dump(config_data, f)
|
||||
|
||||
parser = FlexibleArgumentParser()
|
||||
parser.add_argument("-cc", "--compilation-config", type=json.loads)
|
||||
args = parser.parse_args(["--config", str(config_file_path)])
|
||||
|
||||
assert args.compilation_config == {
|
||||
"mode": 3,
|
||||
"pass_config": {"fuse_allreduce_rms": True},
|
||||
}
|
||||
|
||||
|
||||
def test_compilation_mode_string_values(parser):
|
||||
"""Test that -cc.mode accepts both integer and string mode values."""
|
||||
args = parser.parse_args(["-cc.mode", "0"])
|
||||
|
||||
@@ -444,16 +444,30 @@ class FlexibleArgumentParser(ArgumentParser):
|
||||
|
||||
def load_config_file(self, file_path: str) -> list[str]:
|
||||
"""Loads a yaml file and returns the key value pairs as a
|
||||
flattened list with argparse like pattern
|
||||
flattened list with argparse like pattern.
|
||||
|
||||
Supports both flat configs and nested YAML structures.
|
||||
|
||||
Flat config example:
|
||||
```yaml
|
||||
port: 12323
|
||||
tensor-parallel-size: 4
|
||||
```
|
||||
returns:
|
||||
processed_args: list[str] = [
|
||||
'--port': '12323',
|
||||
'--tensor-parallel-size': '4'
|
||||
]
|
||||
['--port', '12323', '--tensor-parallel-size', '4']
|
||||
|
||||
Nested config example:
|
||||
```yaml
|
||||
compilation-config:
|
||||
pass_config:
|
||||
fuse_allreduce_rms: true
|
||||
speculative-config:
|
||||
model: "nvidia/gpt-oss-120b-Eagle3-v2"
|
||||
num_speculative_tokens: 3
|
||||
```
|
||||
returns:
|
||||
['--compilation-config', '{"pass_config": {"fuse_allreduce_rms": true}}',
|
||||
'--speculative-config', '{"model": "nvidia/gpt-oss-120b-Eagle3-v2", ...}']
|
||||
"""
|
||||
extension: str = file_path.split(".")[-1]
|
||||
if extension not in ("yaml", "yml"):
|
||||
@@ -461,10 +475,10 @@ class FlexibleArgumentParser(ArgumentParser):
|
||||
f"Config file must be of a yaml/yml type. {extension} supplied"
|
||||
)
|
||||
|
||||
# only expecting a flat dictionary of atomic types
|
||||
# Supports both flat configs and nested dicts
|
||||
processed_args: list[str] = []
|
||||
|
||||
config: dict[str, int | str] = {}
|
||||
config: dict[str, Any] = {}
|
||||
try:
|
||||
with open(file_path) as config_file:
|
||||
config = yaml.safe_load(config_file)
|
||||
@@ -484,6 +498,11 @@ class FlexibleArgumentParser(ArgumentParser):
|
||||
processed_args.append("--" + key)
|
||||
for item in value:
|
||||
processed_args.append(str(item))
|
||||
elif isinstance(value, dict):
|
||||
# Convert nested dicts to JSON strings so they can be parsed
|
||||
# by the existing JSON argument parsing machinery.
|
||||
processed_args.append("--" + key)
|
||||
processed_args.append(json.dumps(value))
|
||||
else:
|
||||
processed_args.append("--" + key)
|
||||
processed_args.append(str(value))
|
||||
|
||||
Reference in New Issue
Block a user