[Bug fix][Quantization] Fix dummy weight loading (#38478)

Signed-off-by: Josephasafg <ajgard7@gmail.com>
This commit is contained in:
Asaf Gardin
2026-03-30 23:38:02 +03:00
committed by GitHub
parent d9c7db18da
commit 1fc69f59bb
3 changed files with 19 additions and 13 deletions

View File

@@ -1,13 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import torch.nn as nn
from vllm.config import ModelConfig
from vllm.config.load import LoadConfig
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.reload.meta import materialize_meta_tensor
from vllm.model_executor.model_loader.reload.utils import get_layer_tensors
from vllm.model_executor.model_loader.weight_utils import initialize_dummy_weights
@@ -26,12 +23,6 @@ class DummyModelLoader(BaseModelLoader):
pass # Nothing to download
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
# materialize meta tensors as part of online quantization lifecycle
for layer in model.modules():
for name, param in get_layer_tensors(layer).items():
if param.device == torch.device("meta"):
setattr(layer, name, materialize_meta_tensor(param))
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
initialize_dummy_weights(model, model_config)

View File

@@ -11,7 +11,10 @@ from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention, MLAAttention
from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader,
initialize_single_dummy_weight,
)
from .meta import (
capture_layer_to_meta,
@@ -223,7 +226,8 @@ def finalize_layerwise_processing(model: torch.nn.Module, model_config: ModelCon
elif info.load_numel <= 0:
# first load but received no weights. This happens on dummy load
if info.kernel_tensors is None:
materialize_layer(layer, info)
_layerwise_process(layer, info)
continue
# reloading: place kernel tensors back as a fallback
else:
@@ -258,6 +262,12 @@ def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo):
# Materialize layer tensors onto device
materialize_layer(layer, info)
# If no weights were loaded (e.g. dummy loading), initialize with
# small random values to avoid NaN from zero/garbage data
if len(info.loaded_weights) <= 0:
for tensor in get_layer_tensors(layer).values():
initialize_single_dummy_weight(tensor)
# Reset online quantization flag so process_weights_after_loading
# will run again during reload
if hasattr(layer, "_already_called_process_weights_after_loading"):

View File

@@ -239,10 +239,12 @@ def convert_bin_to_safetensor_file(
sf_size = os.stat(sf_filename).st_size
pt_size = os.stat(pt_filename).st_size
if (sf_size - pt_size) / pt_size > 0.01:
raise RuntimeError(f"""The file size different is more than 1%:
raise RuntimeError(
f"""The file size different is more than 1%:
- {sf_filename}: {sf_size}
- {pt_filename}: {pt_size}
""")
"""
)
# check if the tensors are the same
reloaded = load_file(sf_filename)
@@ -1377,6 +1379,9 @@ def initialize_single_dummy_weight(
high: float = 1e-3,
seed: int = 1234,
) -> None:
if param.device.type == "meta":
return # deferred to finalize_layerwise_processing (e.g. online quant)
if not torch.is_floating_point(param):
if current_platform.is_rocm():
# On ROCm, integer params (e.g. GPTQ qweight/qzeros) are left