[torchao] safetensors integration (#25969)
Signed-off-by: Angel Li <liangel@meta.com>
This commit is contained in:
@@ -14,6 +14,7 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.torchao import torchao_version_at_least
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
download_safetensors_index_file_from_hf,
|
||||
@@ -272,6 +273,10 @@ class DefaultModelLoader(BaseModelLoader):
|
||||
)
|
||||
|
||||
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
|
||||
if model_config.quantization == "torchao" and torchao_version_at_least(
|
||||
"0.14.0"
|
||||
):
|
||||
self.load_config.safetensors_load_strategy = "torchao"
|
||||
weights_to_load = {name for name, _ in model.named_parameters()}
|
||||
|
||||
# if we don't have `model.weight_metadata_and_attr_saved` defined and
|
||||
|
||||
Reference in New Issue
Block a user