[CLI] Improve CLI arg parsing for -O/--compilation-config (#20156)

Signed-off-by: luka <luka@neuralmagic.com>
This commit is contained in:
Luka Govedič
2025-06-30 21:03:13 -04:00
committed by GitHub
parent ded1fb635b
commit 6d42ce8315
5 changed files with 124 additions and 40 deletions

View File

@@ -239,32 +239,40 @@ def test_compilation_config():
assert args.compilation_config == CompilationConfig()
# set to O3
args = parser.parse_args(["-O3"])
assert args.compilation_config.level == 3
args = parser.parse_args(["-O0"])
assert args.compilation_config.level == 0
# set to O 3 (space)
args = parser.parse_args(["-O", "3"])
assert args.compilation_config.level == 3
args = parser.parse_args(["-O", "1"])
assert args.compilation_config.level == 1
# set to O 3 (equals)
args = parser.parse_args(["-O=3"])
args = parser.parse_args(["-O=2"])
assert args.compilation_config.level == 2
# set to O.level 3
args = parser.parse_args(["-O.level", "3"])
assert args.compilation_config.level == 3
# set to string form of a dict
args = parser.parse_args([
"--compilation-config",
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}',
"-O",
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], '
'"use_inductor": false}',
])
assert (args.compilation_config.level == 3 and
args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8])
args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]
and not args.compilation_config.use_inductor)
# set to string form of a dict
args = parser.parse_args([
"--compilation-config="
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8]}',
'{"level": 3, "cudagraph_capture_sizes": [1, 2, 4, 8], '
'"use_inductor": true}',
])
assert (args.compilation_config.level == 3 and
args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8])
args.compilation_config.cudagraph_capture_sizes == [1, 2, 4, 8]
and args.compilation_config.use_inductor)
def test_prefix_cache_default():