diff --git a/vllm/model_executor/model_loader/dummy_loader.py b/vllm/model_executor/model_loader/dummy_loader.py
index 156071f1d..98f82a5b7 100644
--- a/vllm/model_executor/model_loader/dummy_loader.py
+++ b/vllm/model_executor/model_loader/dummy_loader.py
@@ -4,8 +4,19 @@
 import torch.nn as nn
 from vllm.config import ModelConfig
 from vllm.config.load import LoadConfig
+from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
 from vllm.model_executor.model_loader.base_loader import BaseModelLoader
-from vllm.model_executor.model_loader.weight_utils import initialize_dummy_weights
+from vllm.model_executor.model_loader.reload.layerwise import (
+    _get_original_loader,
+    get_layerwise_info,
+)
+from vllm.model_executor.model_loader.reload.meta import materialize_layer
+from vllm.model_executor.model_loader.reload.types import LayerReloadingInfo
+from vllm.model_executor.model_loader.reload.utils import get_layer_tensors
+from vllm.model_executor.model_loader.weight_utils import (
+    initialize_dummy_weights,
+    initialize_single_dummy_weight,
+)
 
 
 class DummyModelLoader(BaseModelLoader):
@@ -23,6 +34,31 @@ class DummyModelLoader(BaseModelLoader):
         pass  # Nothing to download
 
     def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
-        # NOTE(woosuk): For accurate performance evaluation, we assign
-        # random values to the weights.
-        initialize_dummy_weights(model, model_config)
+        for layer in model.modules():
+            info = get_layerwise_info(layer)
+            if info.can_load():
+                self._process_online_quant_layer(layer, info)
+            else:
+                # NOTE(woosuk): For accurate performance evaluation, we assign
+                # random values to the weights.
+                initialize_dummy_weights(layer, model_config)
+
+    def _process_online_quant_layer(
+        self,
+        layer: nn.Module,
+        info: LayerReloadingInfo,
+    ) -> None:
+        """Materialize, apply dummy weights, and run quantization processing."""
+        materialize_layer(layer, info)
+
+        for tensor in get_layer_tensors(layer).values():
+            initialize_single_dummy_weight(tensor)
+
+        for param in get_layer_tensors(layer).values():
+            param.weight_loader = _get_original_loader(param)
+
+        quant_method = getattr(layer, "quant_method", None)
+        if isinstance(quant_method, QuantizeMethodBase):
+            quant_method.process_weights_after_loading(layer)
+
+        info.reset()
diff --git a/vllm/model_executor/model_loader/reload/layerwise.py b/vllm/model_executor/model_loader/reload/layerwise.py
index f7b1c13bb..5d4af7d1f 100644
--- a/vllm/model_executor/model_loader/reload/layerwise.py
+++ b/vllm/model_executor/model_loader/reload/layerwise.py
@@ -11,10 +11,7 @@
 from vllm.config import ModelConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.attention import Attention, MLAAttention
 from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
-from vllm.model_executor.model_loader.weight_utils import (
-    default_weight_loader,
-    initialize_single_dummy_weight,
-)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from .meta import (
     capture_layer_to_meta,
@@ -224,7 +221,7 @@ def finalize_layerwise_processing(model: torch.nn.Module, model_config: ModelCon
 
         # No weights were loaded
         elif info.load_numel <= 0:
-            # first load but received no weights. This happens on dummy load
+            # first load: checkpoint did not contain weights for this layer
             if info.kernel_tensors is None:
                 _layerwise_process(layer, info)
                 continue
@@ -262,12 +259,6 @@ def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo):
 
     # Materialize layer tensors onto device
     materialize_layer(layer, info)
-    # If no weights were loaded (e.g. dummy loading), initialize with
-    # small random values to avoid NaN from zero/garbage data
-    if len(info.loaded_weights) <= 0:
-        for tensor in get_layer_tensors(layer).values():
-            initialize_single_dummy_weight(tensor)
-
     # Reset online quantization flag so process_weights_after_loading
     # will run again during reload
     if hasattr(layer, "_already_called_process_weights_after_loading"):