[Quantization] Consolidate dummy format logic into DummyModelLoader (#38637)

Signed-off-by: Josephasafg <ajgard7@gmail.com>
Author: Asaf Gardin
Date: 2026-04-01 01:20:45 +03:00
Committed by: GitHub
Parent: cc671cb110
Commit: 3dc01ef352
2 changed files with 42 additions and 15 deletions
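For context, "dummy" loading skips reading a checkpoint entirely and fills parameters with small random values so benchmarks exercise realistic, non-zero arithmetic. A minimal sketch of that idea (the helper name and value range here are illustrative assumptions, not vLLM's exact implementation):

    import torch
    import torch.nn as nn

    def fill_with_dummy_values(model: nn.Module, low: float = -1e-3, high: float = 1e-3) -> None:
        # Fill every floating-point parameter with small uniform noise;
        # zeros or uninitialized memory could yield NaNs during profiling.
        for param in model.parameters():
            if torch.is_floating_point(param):
                param.data.uniform_(low, high)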

@@ -4,8 +4,19 @@ import torch.nn as nn
 from vllm.config import ModelConfig
 from vllm.config.load import LoadConfig
 from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
 from vllm.model_executor.model_loader.base_loader import BaseModelLoader
-from vllm.model_executor.model_loader.weight_utils import initialize_dummy_weights
+from vllm.model_executor.model_loader.reload.layerwise import (
+    _get_original_loader,
+    get_layerwise_info,
+)
+from vllm.model_executor.model_loader.reload.meta import materialize_layer
+from vllm.model_executor.model_loader.reload.types import LayerReloadingInfo
+from vllm.model_executor.model_loader.reload.utils import get_layer_tensors
+from vllm.model_executor.model_loader.weight_utils import (
+    initialize_dummy_weights,
+    initialize_single_dummy_weight,
+)


 class DummyModelLoader(BaseModelLoader):
@@ -23,6 +34,31 @@ class DummyModelLoader(BaseModelLoader):
         pass # Nothing to download

     def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
-        # NOTE(woosuk): For accurate performance evaluation, we assign
-        # random values to the weights.
-        initialize_dummy_weights(model, model_config)
+        for layer in model.modules():
+            info = get_layerwise_info(layer)
+            if info.can_load():
+                self._process_online_quant_layer(layer, info)
+            else:
+                # NOTE(woosuk): For accurate performance evaluation, we assign
+                # random values to the weights.
+                initialize_dummy_weights(layer, model_config)
+
+    def _process_online_quant_layer(
+        self,
+        layer: nn.Module,
+        info: LayerReloadingInfo,
+    ) -> None:
+        """Materialize, apply dummy weights, and run quantization processing."""
+        materialize_layer(layer, info)
+        for tensor in get_layer_tensors(layer).values():
+            initialize_single_dummy_weight(tensor)
+        for param in get_layer_tensors(layer).values():
+            param.weight_loader = _get_original_loader(param)
+        quant_method = getattr(layer, "quant_method", None)
+        if isinstance(quant_method, QuantizeMethodBase):
+            quant_method.process_weights_after_loading(layer)
+        info.reset()
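With this change, a dummy load still routes online-quantized layers through their post-load hooks, so profiling with quantized configs behaves like a real load. Exercising the path end to end looks roughly like this (the model name is a placeholder; load_format="dummy" is the option that selects DummyModelLoader):

    from vllm import LLM

    # Weights are randomly initialized instead of downloaded; the model's
    # config and tokenizer are still fetched as usual.
    llm = LLM(model="facebook/opt-125m", load_format="dummy")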

@@ -11,10 +11,7 @@ from vllm.config import ModelConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.attention import Attention, MLAAttention
 from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
-from vllm.model_executor.model_loader.weight_utils import (
-    default_weight_loader,
-    initialize_single_dummy_weight,
-)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from .meta import (
     capture_layer_to_meta,
@@ -224,7 +221,7 @@ def finalize_layerwise_processing(model: torch.nn.Module, model_config: ModelConfig

         # No weights were loaded
         elif info.load_numel <= 0:
-            # first load but received no weights. This happens on dummy load
+            # first load: checkpoint did not contain weights for this layer
            if info.kernel_tensors is None:
                _layerwise_process(layer, info)
                continue
@@ -262,12 +259,6 @@ def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo):
     # Materialize layer tensors onto device
     materialize_layer(layer, info)

-    # If no weights were loaded (e.g. dummy loading), initialize with
-    # small random values to avoid NaN from zero/garbage data
-    if len(info.loaded_weights) <= 0:
-        for tensor in get_layer_tensors(layer).values():
-            initialize_single_dummy_weight(tensor)
-
     # Reset online quantization flag so process_weights_after_loading
     # will run again during reload
     if hasattr(layer, "_already_called_process_weights_after_loading"):
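The flag reset here is what lets process_weights_after_loading run again on a reload. A toy illustration of why the dummy loader re-runs that hook after re-randomizing weights (ToyQuantMethod is a made-up stand-in, not a real QuantizeMethodBase subclass):

    import torch.nn as nn

    class ToyQuantMethod:
        # Stand-in for a QuantizeMethodBase implementation (illustrative only).
        def process_weights_after_loading(self, layer: nn.Module) -> None:
            # Real quant methods repack or derive tensors from the weights;
            # recording per-row scales is a minimal stand-in for that step.
            layer.weight_scale = layer.weight.abs().amax(dim=-1, keepdim=True)

    layer = nn.Linear(16, 16)
    layer.quant_method = ToyQuantMethod()

    # After dummy weights are (re)initialized, the hook must run again so
    # derived tensors (scales, packed formats) match the new random values.
    layer.quant_method.process_weights_after_loading(layer)
    print(layer.weight_scale.shape)  # torch.Size([16, 1])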