Fix AOPerModuleConfig name changes (#18869)

Signed-off-by: Jerry Zhang <jerryzh168@gmail.com>
This commit is contained in:
Jerry Zhang
2025-06-05 21:51:32 -04:00
committed by GitHub
parent cb6d572e85
commit c8134bea15
3 changed files with 25 additions and 5 deletions

View File

@@ -6,6 +6,7 @@ import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -13,12 +14,28 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
logger = init_logger(__name__)
class TorchAOConfig(QuantizationConfig):
"""Config class for torchao."""
def __init__(self, torchao_config) -> None:
self.torchao_config = torchao_config
"""
# TorchAO quantization relies on tensor subclasses. In order,
# to enable proper caching this needs standalone compile
if is_torch_equal_or_newer("2.8.0"):
os.environ["VLLM_TEST_STANDALONE_COMPILE"] = "1"
logger.info(
"Using TorchAO: Setting VLLM_TEST_STANDALONE_COMPILE=1")
# TODO: remove after the torch dependency is updated to 2.8
if is_torch_equal_or_newer(
"2.7.0") and not is_torch_equal_or_newer("2.8.0"):
os.environ["VLLM_DISABLE_COMPILE_CACHE"] = "1"
logger.info("Using TorchAO: Setting VLLM_DISABLE_COMPILE_CACHE=1")
"""
def __repr__(self) -> str:
return f"TorchAOConfig({self.torchao_config})"
@@ -61,10 +78,10 @@ class TorchAOConfig(QuantizationConfig):
if not isinstance(layer, LinearBase):
return None
from torchao.quantization import AOPerModuleConfig
from torchao.quantization import ModuleFqnToConfig
module_fqn = prefix
if isinstance(self.torchao_config, AOPerModuleConfig):
if isinstance(self.torchao_config, ModuleFqnToConfig):
module_fqn_to_config = self.torchao_config.module_fqn_to_config
c = module_fqn_to_config.get(
module_fqn) or module_fqn_to_config.get("_default", None)