[Bug fix][Quantization] Fix dummy weight loading (#38478)
Signed-off-by: Josephasafg <ajgard7@gmail.com>
This commit is contained in:
@@ -1,13 +1,10 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.reload.meta import materialize_meta_tensor
|
||||
from vllm.model_executor.model_loader.reload.utils import get_layer_tensors
|
||||
from vllm.model_executor.model_loader.weight_utils import initialize_dummy_weights
|
||||
|
||||
|
||||
@@ -26,12 +23,6 @@ class DummyModelLoader(BaseModelLoader):
|
||||
pass # Nothing to download
|
||||
|
||||
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
|
||||
# materialize meta tensors as part of online quantization lifecycle
|
||||
for layer in model.modules():
|
||||
for name, param in get_layer_tensors(layer).items():
|
||||
if param.device == torch.device("meta"):
|
||||
setattr(layer, name, materialize_meta_tensor(param))
|
||||
|
||||
# NOTE(woosuk): For accurate performance evaluation, we assign
|
||||
# random values to the weights.
|
||||
initialize_dummy_weights(model, model_config)
|
||||
|
||||
@@ -11,7 +11,10 @@ from vllm.config import ModelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention import Attention, MLAAttention
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader,
|
||||
initialize_single_dummy_weight,
|
||||
)
|
||||
|
||||
from .meta import (
|
||||
capture_layer_to_meta,
|
||||
@@ -223,7 +226,8 @@ def finalize_layerwise_processing(model: torch.nn.Module, model_config: ModelCon
|
||||
elif info.load_numel <= 0:
|
||||
# first load but received no weights. This happens on dummy load
|
||||
if info.kernel_tensors is None:
|
||||
materialize_layer(layer, info)
|
||||
_layerwise_process(layer, info)
|
||||
continue
|
||||
|
||||
# reloading: place kernel tensors back as a fallback
|
||||
else:
|
||||
@@ -258,6 +262,12 @@ def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo):
|
||||
# Materialize layer tensors onto device
|
||||
materialize_layer(layer, info)
|
||||
|
||||
# If no weights were loaded (e.g. dummy loading), initialize with
|
||||
# small random values to avoid NaN from zero/garbage data
|
||||
if len(info.loaded_weights) <= 0:
|
||||
for tensor in get_layer_tensors(layer).values():
|
||||
initialize_single_dummy_weight(tensor)
|
||||
|
||||
# Reset online quantization flag so process_weights_after_loading
|
||||
# will run again during reload
|
||||
if hasattr(layer, "_already_called_process_weights_after_loading"):
|
||||
|
||||
@@ -239,10 +239,12 @@ def convert_bin_to_safetensor_file(
|
||||
sf_size = os.stat(sf_filename).st_size
|
||||
pt_size = os.stat(pt_filename).st_size
|
||||
if (sf_size - pt_size) / pt_size > 0.01:
|
||||
raise RuntimeError(f"""The file size different is more than 1%:
|
||||
raise RuntimeError(
|
||||
f"""The file size different is more than 1%:
|
||||
- {sf_filename}: {sf_size}
|
||||
- {pt_filename}: {pt_size}
|
||||
""")
|
||||
"""
|
||||
)
|
||||
|
||||
# check if the tensors are the same
|
||||
reloaded = load_file(sf_filename)
|
||||
@@ -1377,6 +1379,9 @@ def initialize_single_dummy_weight(
|
||||
high: float = 1e-3,
|
||||
seed: int = 1234,
|
||||
) -> None:
|
||||
if param.device.type == "meta":
|
||||
return # deferred to finalize_layerwise_processing (e.g. online quant)
|
||||
|
||||
if not torch.is_floating_point(param):
|
||||
if current_platform.is_rocm():
|
||||
# On ROCm, integer params (e.g. GPTQ qweight/qzeros) are left
|
||||
|
||||
Reference in New Issue
Block a user