Fix AOPerModuleConfig name changes (#18869)
Signed-off-by: Jerry Zhang <jerryzh168@gmail.com>
This commit is contained in:
@@ -6,6 +6,7 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||
UnquantizedLinearMethod)
|
||||
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||
@@ -13,12 +14,28 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class TorchAOConfig(QuantizationConfig):
|
||||
"""Config class for torchao."""
|
||||
|
||||
def __init__(self, torchao_config) -> None:
|
||||
self.torchao_config = torchao_config
|
||||
"""
|
||||
# TorchAO quantization relies on tensor subclasses. In order,
|
||||
# to enable proper caching this needs standalone compile
|
||||
if is_torch_equal_or_newer("2.8.0"):
|
||||
os.environ["VLLM_TEST_STANDALONE_COMPILE"] = "1"
|
||||
logger.info(
|
||||
"Using TorchAO: Setting VLLM_TEST_STANDALONE_COMPILE=1")
|
||||
|
||||
# TODO: remove after the torch dependency is updated to 2.8
|
||||
if is_torch_equal_or_newer(
|
||||
"2.7.0") and not is_torch_equal_or_newer("2.8.0"):
|
||||
os.environ["VLLM_DISABLE_COMPILE_CACHE"] = "1"
|
||||
logger.info("Using TorchAO: Setting VLLM_DISABLE_COMPILE_CACHE=1")
|
||||
"""
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"TorchAOConfig({self.torchao_config})"
|
||||
@@ -61,10 +78,10 @@ class TorchAOConfig(QuantizationConfig):
|
||||
if not isinstance(layer, LinearBase):
|
||||
return None
|
||||
|
||||
from torchao.quantization import AOPerModuleConfig
|
||||
from torchao.quantization import ModuleFqnToConfig
|
||||
|
||||
module_fqn = prefix
|
||||
if isinstance(self.torchao_config, AOPerModuleConfig):
|
||||
if isinstance(self.torchao_config, ModuleFqnToConfig):
|
||||
module_fqn_to_config = self.torchao_config.module_fqn_to_config
|
||||
c = module_fqn_to_config.get(
|
||||
module_fqn) or module_fqn_to_config.get("_default", None)
|
||||
|
||||
Reference in New Issue
Block a user