[Quantization] Consolidate dummy format logic into DummyModelLoader (#38637)
Signed-off-by: Josephasafg <ajgard7@gmail.com>
This commit is contained in:
@@ -4,8 +4,19 @@ import torch.nn as nn
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.weight_utils import initialize_dummy_weights
|
||||
from vllm.model_executor.model_loader.reload.layerwise import (
|
||||
_get_original_loader,
|
||||
get_layerwise_info,
|
||||
)
|
||||
from vllm.model_executor.model_loader.reload.meta import materialize_layer
|
||||
from vllm.model_executor.model_loader.reload.types import LayerReloadingInfo
|
||||
from vllm.model_executor.model_loader.reload.utils import get_layer_tensors
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
initialize_dummy_weights,
|
||||
initialize_single_dummy_weight,
|
||||
)
|
||||
|
||||
|
||||
class DummyModelLoader(BaseModelLoader):
|
||||
@@ -23,6 +34,31 @@ class DummyModelLoader(BaseModelLoader):
|
||||
pass # Nothing to download
|
||||
|
||||
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
|
||||
# NOTE(woosuk): For accurate performance evaluation, we assign
|
||||
# random values to the weights.
|
||||
initialize_dummy_weights(model, model_config)
|
||||
for layer in model.modules():
|
||||
info = get_layerwise_info(layer)
|
||||
if info.can_load():
|
||||
self._process_online_quant_layer(layer, info)
|
||||
else:
|
||||
# NOTE(woosuk): For accurate performance evaluation, we assign
|
||||
# random values to the weights.
|
||||
initialize_dummy_weights(layer, model_config)
|
||||
|
||||
def _process_online_quant_layer(
|
||||
self,
|
||||
layer: nn.Module,
|
||||
info: LayerReloadingInfo,
|
||||
) -> None:
|
||||
"""Materialize, apply dummy weights, and run quantization processing."""
|
||||
materialize_layer(layer, info)
|
||||
|
||||
for tensor in get_layer_tensors(layer).values():
|
||||
initialize_single_dummy_weight(tensor)
|
||||
|
||||
for param in get_layer_tensors(layer).values():
|
||||
param.weight_loader = _get_original_loader(param)
|
||||
|
||||
quant_method = getattr(layer, "quant_method", None)
|
||||
if isinstance(quant_method, QuantizeMethodBase):
|
||||
quant_method.process_weights_after_loading(layer)
|
||||
|
||||
info.reset()
|
||||
|
||||
@@ -11,10 +11,7 @@ from vllm.config import ModelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.attention import Attention, MLAAttention
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
default_weight_loader,
|
||||
initialize_single_dummy_weight,
|
||||
)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
|
||||
from .meta import (
|
||||
capture_layer_to_meta,
|
||||
@@ -224,7 +221,7 @@ def finalize_layerwise_processing(model: torch.nn.Module, model_config: ModelCon
|
||||
|
||||
# No weights were loaded
|
||||
elif info.load_numel <= 0:
|
||||
# first load but received no weights. This happens on dummy load
|
||||
# first load: checkpoint did not contain weights for this layer
|
||||
if info.kernel_tensors is None:
|
||||
_layerwise_process(layer, info)
|
||||
continue
|
||||
@@ -262,12 +259,6 @@ def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo):
|
||||
# Materialize layer tensors onto device
|
||||
materialize_layer(layer, info)
|
||||
|
||||
# If no weights were loaded (e.g. dummy loading), initialize with
|
||||
# small random values to avoid NaN from zero/garbage data
|
||||
if len(info.loaded_weights) <= 0:
|
||||
for tensor in get_layer_tensors(layer).values():
|
||||
initialize_single_dummy_weight(tensor)
|
||||
|
||||
# Reset online quantization flag so process_weights_after_loading
|
||||
# will run again during reload
|
||||
if hasattr(layer, "_already_called_process_weights_after_loading"):
|
||||
|
||||
Reference in New Issue
Block a user