[torchao] Add support for ModuleFqnToConfig using regex (#26001)

Signed-off-by: Jerry Zhang <jerryzh168@gmail.com>
This commit is contained in:
Jerry Zhang
2025-10-09 01:32:32 -07:00
committed by GitHub
parent cf4cd6c24f
commit a83ff278d6
2 changed files with 38 additions and 3 deletions

View File

@@ -5,6 +5,7 @@ import json
from importlib.util import find_spec
from typing import Any, Optional
import regex as re
import torch
import torch.nn.functional as F
from packaging import version
@@ -192,9 +193,26 @@ class TorchAOConfig(QuantizationConfig):
module_fqn = prefix
if isinstance(self.torchao_config, ModuleFqnToConfig):
module_fqn_to_config = self.torchao_config.module_fqn_to_config
c = module_fqn_to_config.get(module_fqn) or module_fqn_to_config.get(
"_default", None
)
c = None
if module_fqn in module_fqn_to_config:
assert not module_fqn.startswith("re:"), (
"module fqn should not start with"
"`re:`, which is used for specifying regex"
)
c = module_fqn_to_config[module_fqn]
else:
for maybe_module_fqn_pattern in module_fqn_to_config:
if not maybe_module_fqn_pattern.startswith("re:"):
continue
elif re.fullmatch(maybe_module_fqn_pattern[3:], module_fqn):
# we'll apply the config for first fully matched pattern
c = module_fqn_to_config[maybe_module_fqn_pattern]
break
else:
# fallback to use default if no module specific
# config is provided
c = module_fqn_to_config.get("_default", None)
if c is not None:
current_torchao_config = TorchAOConfig(
c, self.skip_modules, self.is_checkpoint_torchao_serialized