[Quantization] Add field to skip unquantized modules for GPTQ config (#25455)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Isotr0py
2025-09-26 23:47:41 +08:00
committed by GitHub
parent db1e42f627
commit d4d9899860
16 changed files with 219 additions and 153 deletions

View File

@@ -4,6 +4,7 @@
import json
import os
import time
from dataclasses import asdict
from functools import cache, partial
from pathlib import Path
from typing import Any, Callable, Literal, Optional, TypeVar, Union
@@ -27,7 +28,8 @@ from transformers.utils import CONFIG_NAME as HF_CONFIG_NAME
from vllm import envs
from vllm.logger import init_logger
from vllm.transformers_utils.config_parser_base import ConfigParserBase
from vllm.transformers_utils.utils import check_gguf_file
from vllm.transformers_utils.utils import (check_gguf_file,
parse_safetensors_file_metadata)
if envs.VLLM_USE_MODELSCOPE:
from modelscope import AutoConfig
@@ -999,6 +1001,34 @@ def try_get_tokenizer_config(
return None
def get_safetensors_params_metadata(
model: str,
*,
revision: Optional[str] = None,
) -> dict[str, Any]:
"""
Get the safetensors metadata for remote model repository.
"""
full_metadata = {}
if (model_path := Path(model)).exists():
safetensors_to_check = model_path.glob("*.safetensors")
full_metadata = {
param_name: info
for file_path in safetensors_to_check if file_path.is_file()
for param_name, info in parse_safetensors_file_metadata(
file_path).items()
}
else:
repo_mt = try_get_safetensors_metadata(model, revision=revision)
if repo_mt and (files_mt := repo_mt.files_metadata):
full_metadata = {
param_name: asdict(info)
for file_mt in files_mt.values()
for param_name, info in file_mt.tensors.items()
}
return full_metadata
def _download_mistral_config_file(model, revision) -> dict:
config_file_name = "params.json"
config_dict = get_hf_file_to_dict(config_file_name, model, revision)

View File

@@ -2,10 +2,11 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import struct
from functools import cache
from os import PathLike
from pathlib import Path
from typing import Optional, Union
from typing import Any, Optional, Union
from vllm.envs import VLLM_MODEL_REDIRECT_PATH
from vllm.logger import init_logger
@@ -97,3 +98,11 @@ def maybe_model_redirect(model: str) -> str:
return redirect_model
return model
def parse_safetensors_file_metadata(
path: Union[str, PathLike]) -> dict[str, Any]:
with open(path, "rb") as f:
length_of_metadata = struct.unpack('<Q', f.read(8))[0]
metadata = json.loads(f.read(length_of_metadata).decode('utf-8'))
return metadata