[Frontend][torch.compile] CompilationConfig Overhaul (#20283): Set up -O infrastructure (#26847)

Signed-off-by: morrison-turnansky <mturnans@redhat.com>
Signed-off-by: adabeyta <aabeyta@redhat.com>
Signed-off-by: Morrison Turnansky <mturnans@redhat.com>
Co-authored-by: adabeyta <aabeyta@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Morrison Turnansky
2025-11-27 04:55:58 -05:00
committed by GitHub
parent 00d3310d2d
commit 0838b52e2e
13 changed files with 735 additions and 64 deletions

View File

@@ -247,16 +247,16 @@ class FlexibleArgumentParser(ArgumentParser):
elif arg.startswith("-O") and arg != "-O" and arg[2] != ".":
# allow -O flag to be used without space, e.g. -O3 or -Odecode
# -O.<...> handled later
# also handle -O=<mode> here
mode = arg[3:] if arg[2] == "=" else arg[2:]
processed_args.append(f"-O.mode={mode}")
# also handle -O=<optimization_level> here
optimization_level = arg[3:] if arg[2] == "=" else arg[2:]
processed_args += ["--optimization-level", optimization_level]
elif (
arg == "-O"
and i + 1 < len(args)
and args[i + 1] in {"0", "1", "2", "3"}
):
# Convert -O <n> to -O.mode <n>
processed_args.append("-O.mode")
# Convert -O <n> to --optimization-level <n>
processed_args.append("--optimization-level")
else:
processed_args.append(arg)
@@ -294,10 +294,24 @@ class FlexibleArgumentParser(ArgumentParser):
delete = set[int]()
dict_args = defaultdict[str, dict[str, Any]](dict)
duplicates = set[str]()
# Track regular arguments (non-dict args) for duplicate detection
regular_args_seen = set[str]()
for i, processed_arg in enumerate(processed_args):
if i in delete: # skip if value from previous arg
continue
if processed_arg.startswith("--") and "." not in processed_arg:
if "=" in processed_arg:
arg_name = processed_arg.split("=", 1)[0]
else:
arg_name = processed_arg
if arg_name in regular_args_seen:
duplicates.add(arg_name)
else:
regular_args_seen.add(arg_name)
continue
if processed_arg.startswith("-") and "." in processed_arg:
if "=" in processed_arg:
processed_arg, value_str = processed_arg.split("=", 1)