Support using Int4PreshuffledTensor after loading (#26066)

Signed-off-by: Jerry Zhang <jerryzh168@gmail.com>
This commit is contained in:
Jerry Zhang
2025-11-04 03:00:57 -08:00
committed by GitHub
parent 2ec401bc39
commit 03c4c4aa9d
2 changed files with 208 additions and 4 deletions
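In effect, this change routes weight post-processing through torchao's packed-tensor conversion when a new enough torchao is installed, so that an already-quantized checkpoint can be repacked (e.g. into Int4PreshuffledTensor) for the current hardware after loading. A condensed sketch of the serialized-checkpoint path, assuming the helpers and the conversion function introduced in the diff below are in scope; `repack_weight` is an illustrative name, not part of the change:

# Condensed sketch of the new post-load path for torchao-serialized checkpoints.
# _get_weight_attrs, _restore_weight_attrs and
# convert_to_packed_tensor_based_on_current_hardware are the names added in the
# diff below; the conversion falls back to a no-op on torchao < 0.15.0.
import torch
from torch.nn import Parameter


def repack_weight(layer: torch.nn.Module) -> None:
    if not hasattr(layer, "weight"):
        return
    recorded = _get_weight_attrs(layer.weight)  # keep vLLM's weight attributes
    layer.weight = Parameter(
        convert_to_packed_tensor_based_on_current_hardware(layer.weight),
        requires_grad=layer.weight.requires_grad,
    )
    _restore_weight_attrs(layer.weight, recorded)  # re-attach them on the new Parameter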


@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import importlib
import json
import types
from importlib.util import find_spec
from typing import Any, Optional
@@ -27,6 +28,39 @@ from vllm.model_executor.utils import set_weight_attrs
logger = init_logger(__name__)

def _bond_method_to_cls(func, obj):
    if hasattr(func, "__self__") or not callable(func):
        # If the function is already bound to an instance, return it as is
        return func
    else:
        return types.MethodType(func, obj)

def _get_weight_attrs(param):
    # record the attributes attached to the weight so we can
    # restore them later
    recorded_weight_attr = {}
    for key in param.__dict__:
        if hasattr(param, key):
            attr = getattr(param, key)
            if not callable(attr):
                recorded_weight_attr[key] = attr
            elif hasattr(attr, "__self__") and param is attr.__self__:
                # if attr is a method bound to this instance (param is
                # attr.__self__), record the underlying function object
                recorded_weight_attr[key] = attr.__func__
            else:
                recorded_weight_attr[key] = attr
    return recorded_weight_attr

def _restore_weight_attrs(param, recorded_weight_attr):
    for attr_name, attr in recorded_weight_attr.items():
        if not hasattr(param, attr_name):
            setattr(param, attr_name, _bond_method_to_cls(attr, param))
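As an aside, a minimal round trip with the two helpers above; this is an illustrative, standalone example (the attribute names are made up), not part of the diff. In vLLM the recorded attributes are typically markers like input_dim/output_dim and the parameter's weight_loader:

# Record attributes from one Parameter and re-attach them to its replacement.
import types

import torch
from torch.nn import Parameter

w = Parameter(torch.randn(8, 4), requires_grad=False)
w.output_dim = 0                                                    # plain attribute
w.describe = types.MethodType(lambda self: tuple(self.shape), w)    # bound method

recorded = _get_weight_attrs(w)                 # snapshot before replacing the weight
new_w = Parameter(w.detach().clone(), requires_grad=False)
_restore_weight_attrs(new_w, recorded)          # re-attach on the new Parameter

assert new_w.output_dim == 0
assert new_w.describe() == (8, 4)               # rebound to new_w, not the old weight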

def torchao_version_at_least(torchao_version: str) -> bool:
    if find_spec("torchao"):
        try:
@@ -57,6 +91,14 @@ def should_skip(prefix: str, skip_modules: list[str]) -> bool:
    return False


if torchao_version_at_least("0.15.0"):
    from torchao.prototype.tensor_conversion.api import (
        convert_to_packed_tensor_based_on_current_hardware,
    )
else:
    convert_to_packed_tensor_based_on_current_hardware = lambda t: t

class TorchAOConfig(QuantizationConfig):
    """Config class for torchao."""
@@ -307,12 +349,32 @@ class TorchAOLinearMethod(LinearMethodBase):
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        if self.quant_config.is_checkpoint_torchao_serialized:
            if not hasattr(layer, "weight"):
                return
            # record the attributes attached to the weight so we can
            # restore them after the weight is swapped out
            recorded_weight_attr = _get_weight_attrs(layer.weight)
            # repack the already-quantized weight into the layout best suited
            # to the current hardware (e.g. Int4PreshuffledTensor)
            layer.weight = Parameter(
                convert_to_packed_tensor_based_on_current_hardware(layer.weight),
                requires_grad=layer.weight.requires_grad,
            )
            _restore_weight_attrs(layer.weight, recorded_weight_attr)
            return

        # online quantize the weight if the checkpoint is not already
        # quantized by torchao
        recorded_weight_attr = _get_weight_attrs(layer.weight)
        weight = torchao_quantize_param_data(
            layer.weight, self.quant_config.torchao_config
        )
        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
        weight = torch.nn.Parameter(
            convert_to_packed_tensor_based_on_current_hardware(weight),
            weight.requires_grad,
        )
        _restore_weight_attrs(weight, recorded_weight_attr)
        layer.register_parameter("weight", weight)
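
For context, a small standalone illustration (torch only, names are illustrative) of why the record/restore step is needed at all: wrapping a tensor in a fresh Parameter drops any attributes that were attached to the old one, including the input_dim/output_dim markers set via set_weight_attrs above.

# Attributes attached to a Parameter do not survive re-wrapping its data.
import torch
from torch.nn import Parameter

old = Parameter(torch.randn(4, 2), requires_grad=False)
old.output_dim = 0
new = Parameter(old.detach().clone(), requires_grad=False)
assert not hasattr(new, "output_dim")  # lost unless explicitly restored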