[torchao] safetensors integration (#25969)

Signed-off-by: Angel Li <liangel@meta.com>
This commit is contained in:
liangel-02
2025-10-07 19:12:35 -07:00
committed by GitHub
parent f80e7866c0
commit b32260ab85
5 changed files with 60 additions and 0 deletions

View File

@@ -14,6 +14,7 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from vllm.config import ModelConfig
from vllm.config.load import LoadConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.torchao import torchao_version_at_least
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.weight_utils import (
download_safetensors_index_file_from_hf,
@@ -272,6 +273,10 @@ class DefaultModelLoader(BaseModelLoader):
)
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
if model_config.quantization == "torchao" and torchao_version_at_least(
"0.14.0"
):
self.load_config.safetensors_load_strategy = "torchao"
weights_to_load = {name for name, _ in model.named_parameters()}
# if we don't have `model.weight_metadata_and_attr_saved` defined and