[QeRL] Compose online quantization with quantized reloading (#38032)
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
@@ -73,7 +73,9 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    cutlass_fp8_supported,
    normalize_e4m3fn_to_e4m3fnuz,
)
-from vllm.model_executor.model_loader.weight_utils import initialize_single_dummy_weight
+from vllm.model_executor.model_loader.reload.layerwise import (
+    initialize_online_processing,
+)
from vllm.model_executor.parameter import (
    BlockQuantScaleParameter,
    ModelWeightParameter,
|
||||
@@ -496,8 +498,8 @@ class Fp8LinearMethod(LinearMethodBase):

class Fp8OnlineLinearMethod(Fp8LinearMethod):
-    """Online version of Fp8LinearMethod, loads the fp16/bf16 checkpoint
-    and quantized the weights during loading."""
+    """Online version of Fp8LinearMethod which loads a full precision checkpoint
+    and quantizes weights during loading."""

    uses_meta_device: bool = True
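For orientation, the quantize-at-load idea reduces, in its simplest per-tensor form, to picking a scale from the weight's dynamic range and casting. A minimal, self-contained sketch in plain PyTorch (illustrative only; the kernels in this file go through helpers like ops.scaled_fp8_quant, and quantize_fp8_per_tensor is a made-up name):

import torch

def quantize_fp8_per_tensor(weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Map the largest magnitude onto the fp8 e4m3 maximum (448.0).
    finfo = torch.finfo(torch.float8_e4m3fn)
    scale = weight.abs().max().float().clamp(min=1e-12) / finfo.max
    q = (weight.float() / scale).clamp(finfo.min, finfo.max)
    return q.to(torch.float8_e4m3fn), scale

bf16_weight = torch.randn(256, 512, dtype=torch.bfloat16)
qweight, scale = quantize_fp8_per_tensor(bf16_weight)
dequantized = qweight.float() * scale  # approximate reconstruction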
@@ -519,84 +521,25 @@ class Fp8OnlineLinearMethod(Fp8LinearMethod):
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

        # WEIGHT
-        def patched_weight_loader(param, loaded_weight, *args, **kwargs):
-            # track how many elements we have updated
-            if not hasattr(layer, "_loaded_numel"):
-                layer._loaded_numel = 0
-
-                # when the first `loaded_weight` is about to be
-                # loaded to `param`, materialize `param` just-in-time
-                weight = ModelWeightParameter(
-                    data=torch.empty_like(layer.weight, device=layer._load_device),
-                    input_dim=1,
-                    output_dim=0,
-                    weight_loader=patched_weight_loader,
-                )
-                _copy_missing_attrs(layer.weight, weight)
-                layer.register_parameter("weight", weight)
-                del layer._load_device
-
-            # refresh the reference to `param` to reflect just-in-time
-            # materialization
-            param = layer.weight
-
-            # load the current weight chunk
-            copy_numel_counter = CopyNumelCounter()
-            with copy_numel_counter:
-                res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
-            layer._loaded_numel += copy_numel_counter.copied_numel
-
-            # if we have loaded all of the elements, call
-            # process_weights_after_loading
-            target_loaded_numel = layer.weight.numel()
-            if layer._loaded_numel == target_loaded_numel:
-                self.process_weights_after_loading(layer)
-
-                # Prevent the usual `process_weights_after_loading` call from doing
-                # anything
-                layer._already_called_process_weights_after_loading = True
-
-                # Note that we keep `layer._loaded_numel` around just in case
-                # there is logic added to vllm in the future which calls a
-                # weight loader twice - we do not want to re-initialize in
-                # that case.
-
-            return res
-
        weight = ModelWeightParameter(
            data=torch.empty(
                output_size_per_partition,
                input_size_per_partition,
-                # materialized just-in-time in `patched_weight_loader`
-                device="meta",
+                device="meta",  # materialized and processed during loading
                dtype=params_dtype,
            ),
            input_dim=1,
            output_dim=0,
-            weight_loader=patched_weight_loader,
+            weight_loader=weight_loader,
        )
-        # stash the correct device for `patched_weight_loader`
-        layer._load_device = torch.get_default_device()
        layer.register_parameter("weight", weight)

+        initialize_online_processing(layer)

    def process_weights_after_loading(self, layer: Module) -> None:
-        if getattr(layer, "_already_called_process_weights_after_loading", False):
-            return
-
-        # deferred initialization of randomly initialized weights for the
-        # `--load_format dummy` feature
-        if layer.weight.device == torch.device("meta"):
-            weight = ModelWeightParameter(
-                data=torch.empty_like(layer.weight, device=layer._load_device),
-                input_dim=1,
-                output_dim=0,
-                weight_loader=layer.weight.weight_loader,
-            )
-            _copy_missing_attrs(layer.weight, weight)
-            layer.register_parameter("weight", weight)
-            initialize_single_dummy_weight(layer.weight)
-
        # TODO(future): support block_quant in online quant path
        assert not self.block_quant
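The device="meta" allocation above is what makes streaming quantization cheap: a meta tensor carries shape and dtype but no storage, so the full-precision weight never has to exist on the load device before its first shard arrives. A small sketch of the pattern (stock PyTorch, independent of vLLM):

import torch

# Shape/dtype bookkeeping only; no memory is allocated.
meta_weight = torch.empty(4096, 4096, dtype=torch.bfloat16, device="meta")
assert meta_weight.device == torch.device("meta")

# On the first shard, materialize storage on the real load device.
device = "cuda" if torch.cuda.is_available() else "cpu"
real_weight = torch.empty_like(meta_weight, device=device)
real_weight[:128].copy_(torch.randn(128, 4096, dtype=torch.bfloat16))  # stand-in shard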
@@ -845,9 +788,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
        )

    def process_weights_after_loading(self, layer: Module) -> None:
-        if getattr(layer, "_already_called_process_weights_after_loading", False):
-            return
-
        # Allow for accessing weights and scales in standard way.
        w13 = layer.w13_weight
        w2 = layer.w2_weight
@@ -892,9 +832,6 @@ class Fp8MoEMethod(FusedMoEMethodBase):
            layer, w13, w2, w13_scale, w2_scale, w13_input_scale, w2_input_scale
        )

-        # Prevent duplicate processing (e.g., during weight reload)
-        layer._already_called_process_weights_after_loading = True
-
    def maybe_make_prepare_finalize(
        self,
        routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
@@ -1013,86 +950,12 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
        layer.orig_dtype = params_dtype
        layer.weight_block_size = None

-        # We are doing online quantization, patch the weight loaded
-        # to call `process_weights_after_loading` in a streaming fashion
-        # as soon as the last weight chunk is loaded.
-        weight_loader = extra_weight_attrs["weight_loader"]
-        # create a new holder to prevent modifying behavior of any other
-        # objects which might depend on the old one
-        new_extra_weight_attrs = extra_weight_attrs

-        def patched_weight_loader(param, loaded_weight, *args, **kwargs):
-            # add a counter to track how many elements we have updated
-            if not hasattr(layer, "_loaded_numel"):
-                layer._loaded_numel = 0
-
-                # save the ids of original w13 and w2 so that we can
-                # distinguish which one `param` should map to further
-                # down in this file
-                layer._w13_weight_orig_id = id(layer.w13_weight)
-                layer._w2_weight_orig_id = id(layer.w2_weight)
-
-                # when the first `loaded_weight` is about to be
-                # loaded to `param`, materialize `param` just-in-time
-
-                w13_weight = torch.nn.Parameter(
-                    torch.empty_like(layer.w13_weight, device=layer._load_device),
-                    requires_grad=False,
-                )
-                set_weight_attrs(w13_weight, extra_weight_attrs)
-                _copy_missing_attrs(layer.w13_weight, w13_weight)
-                layer.register_parameter("w13_weight", w13_weight)
-
-                w2_weight = torch.nn.Parameter(
-                    torch.empty_like(layer.w2_weight, device=layer._load_device),
-                    requires_grad=False,
-                )
-                set_weight_attrs(w2_weight, extra_weight_attrs)
-                _copy_missing_attrs(layer.w2_weight, w2_weight)
-                layer.register_parameter("w2_weight", w2_weight)
-                del layer._load_device
-
-            # refresh the reference to `param` to reflect just-in-time
-            # materialization
-            if id(param) == layer._w13_weight_orig_id:
-                param = layer.w13_weight
-            elif id(param) == layer._w2_weight_orig_id:
-                param = layer.w2_weight
-
-            # load the current weight chunk
-            copy_numel_counter = CopyNumelCounter()
-            with copy_numel_counter:
-                res = weight_loader(param, loaded_weight, *args, **kwargs)  # type: ignore[misc]
-            layer._loaded_numel += copy_numel_counter.copied_numel
-
-            # if we have loaded all of the elements, call
-            # process_weights_after_loading
-            target_loaded_numel = layer.w13_weight.numel() + layer.w2_weight.numel()
-            if layer._loaded_numel == target_loaded_numel:
-                self.process_weights_after_loading(layer)
-
-                # Prevent the usual `process_weights_after_loading` call
-                # from doing anything
-                layer._already_called_process_weights_after_loading = True
-
-                # Note that we keep `layer._loaded_numel`,
-                # `layer._w13_weight_orig_id` and `layer._w2_weight_orig_id`
-                # around because if EP is on, weight loaders for non-local
-                # experts will run but not actually copy any elements, and we
-                # need to not re-initialize in that case.
-
-            return res
-
-        new_extra_weight_attrs["weight_loader"] = patched_weight_loader
-        extra_weight_attrs = new_extra_weight_attrs

        # WEIGHTS
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size,
-                # materialized just-in-time in `patched_weight_loader`
                device="meta",
                dtype=params_dtype,
            ),
@@ -1106,91 +969,53 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
                num_experts,
                hidden_size,
                intermediate_size_per_partition,
-                # materialized just-in-time in `patched_weight_loader`
-                device="meta",
+                device="meta",  # materialized and processed during loading
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)
-        # stash the correct device for `patched_weight_loader`
-        layer._load_device = torch.get_default_device()

        # BIASES (for models like GPT-OSS that have biased MoE)
        if self.moe.has_bias:
-            # Use the original weight_loader (not patched) for biases
-            orig_extra_weight_attrs = dict(extra_weight_attrs)
-            orig_extra_weight_attrs["weight_loader"] = weight_loader
            w13_bias = torch.nn.Parameter(
                torch.zeros(
                    num_experts,
                    2 * intermediate_size_per_partition,
                    device="meta",  # materialized and processed during loading
                    dtype=layer.orig_dtype,
                ),
                requires_grad=False,
            )
            layer.register_parameter("w13_bias", w13_bias)
-            set_weight_attrs(w13_bias, orig_extra_weight_attrs)
+            set_weight_attrs(w13_bias, extra_weight_attrs)

            w2_bias = torch.nn.Parameter(
-                torch.zeros(num_experts, hidden_size, dtype=layer.orig_dtype),
+                torch.zeros(
+                    num_experts,
+                    hidden_size,
+                    device="meta",  # materialized and processed during loading
+                    dtype=layer.orig_dtype,
+                ),
                requires_grad=False,
            )
            layer.register_parameter("w2_bias", w2_bias)
-            set_weight_attrs(w2_bias, orig_extra_weight_attrs)
+            set_weight_attrs(w2_bias, extra_weight_attrs)

        # WEIGHT_SCALES
        # Allocate 2 scales for w1 and w3 respectively.
        # They will be combined to a single scale after weight loading.
        w13_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        w2_weight_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
        )
        layer.register_parameter("w13_weight_scale", w13_weight_scale)
        layer.register_parameter("w2_weight_scale", w2_weight_scale)
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

+        layer.w13_input_scale = None
+        layer.w2_input_scale = None
+        initialize_online_processing(layer)

    def process_weights_after_loading(self, layer: Module) -> None:
-        if getattr(layer, "_already_called_process_weights_after_loading", False):
-            return
-
-        # deferred initialization of randomly initialized weights for the
-        # `--load_format dummy` feature
-        if layer.w13_weight.device == torch.device("meta"):
-            w13_weight = torch.nn.Parameter(
-                torch.empty_like(layer.w13_weight, device=layer._load_device),
-                requires_grad=False,
-            )
-            set_weight_attrs(
-                w13_weight, {"weight_loader": layer.w13_weight.weight_loader}
-            )
-            _copy_missing_attrs(layer.w13_weight, w13_weight)
-            layer.register_parameter("w13_weight", w13_weight)
-            initialize_single_dummy_weight(layer.w13_weight)
-        if layer.w2_weight.device == torch.device("meta"):
-            w2_weight = torch.nn.Parameter(
-                torch.empty_like(layer.w2_weight, device=layer._load_device),
-                requires_grad=False,
-            )
-            set_weight_attrs(
-                w2_weight, {"weight_loader": layer.w2_weight.weight_loader}
-            )
-            _copy_missing_attrs(layer.w2_weight, w2_weight)
-            layer.register_parameter("w2_weight", w2_weight)
-            initialize_single_dummy_weight(layer.w2_weight)
-
        # If checkpoint is fp16, quantize in place.
        fp8_dtype = current_platform.fp8_dtype()
        w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
        w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)
-        w13_scale = layer.w13_weight_scale
-        w2_scale = layer.w2_weight_scale
+        w13_scale = torch.ones(layer.num_experts, dtype=torch.float32)
+        w2_scale = torch.ones(layer.num_experts, dtype=torch.float32)
        layer.w13_input_scale = None
        layer.w2_input_scale = None

        for expert in range(layer.local_num_experts):
            w13[expert, :, :], w13_scale[expert] = ops.scaled_fp8_quant(
@@ -1207,8 +1032,8 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
            w2,
            w13_scale,
            w2_scale,
-            layer.w13_input_scale,
-            layer.w2_input_scale,
+            w13_input_scale=layer.w13_input_scale,
+            w2_input_scale=layer.w2_input_scale,
        )

        # Prevent duplicate processing (e.g., during weight reload)
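The streaming trigger that this change moves into the shared reload machinery boils down to: wrap the loader, count the elements actually written, and fire processing exactly once when the count reaches the parameter size. A toy version of that contract (simplified: it counts shard.numel() directly, whereas the real code counts copy_ elements through a dispatch mode so that no-op loads, e.g. non-local experts under EP, contribute zero):

import torch

def make_counting_loader(loader, state, target_numel, on_complete):
    def wrapped(param, shard, *args, **kwargs):
        result = loader(param, shard, *args, **kwargs)
        state["loaded"] += shard.numel()
        if state["loaded"] == target_numel:
            on_complete()  # e.g. process_weights_after_loading
        return result
    return wrapped

def plain_loader(param, shard, offset):
    param.view(-1)[offset : offset + shard.numel()].copy_(shard.view(-1))

param = torch.zeros(8)
state = {"loaded": 0}
loader = make_counting_loader(
    plain_loader, state, param.numel(),
    lambda: print("last chunk loaded; quantize now"),
)
loader(param, torch.ones(4), 0)
loader(param, torch.ones(4), 4)  # completes the tensor and fires the callback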
@@ -337,6 +337,8 @@ class Mxfp8OnlineMoEMethod(Fp8OnlineMoEMethod):
        w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)
        w13_scale = layer.w13_weight_scale
        w2_scale = layer.w2_weight_scale
+        layer.w13_input_scale = None
+        layer.w2_input_scale = None

        w13, w13_scale = self._quantize_mxfp8_moe_weight(layer.w13_weight)
        w2, w2_scale = self._quantize_mxfp8_moe_weight(layer.w2_weight)
@@ -9,6 +9,7 @@ import vllm.envs as envs
from vllm.config import ModelConfig, VllmConfig
from vllm.config.load import LoadConfig
from vllm.logger import init_logger
+from vllm.model_executor.model_loader.reload import finalize_layerwise_processing
from vllm.model_executor.model_loader.utils import (
    initialize_model,
    process_weights_after_loading,
@@ -49,16 +50,13 @@ class BaseModelLoader(ABC):
            device_config.device if load_config.device is None else load_config.device
        )
        target_device = torch.device(load_device)
-        with set_default_torch_dtype(model_config.dtype):
-            with target_device:
-                model = initialize_model(
-                    vllm_config=vllm_config, model_config=model_config, prefix=prefix
-                )
-
+        with set_default_torch_dtype(model_config.dtype), target_device:
+            model = initialize_model(
+                vllm_config=vllm_config, model_config=model_config, prefix=prefix
+            )
        log_model_inspection(model)

        logger.debug("Loading weights on %s ...", load_device)
        # Quantization does not happen in `load_weights` but after it
        self.load_weights(model, model_config)

        # Log peak GPU memory after loading weights. This is needed
@@ -71,6 +69,11 @@ class BaseModelLoader(ABC):
            scope="local",
        )

+        # Process weights into kernel format. Note that when using online
+        # quantization, weights are (typically) quantized as they are loaded.
+        if _has_online_quant(model):
+            finalize_layerwise_processing(model, model_config)
+
        process_weights_after_loading(model, model_config, target_device)

        return model.eval()
@@ -84,3 +87,12 @@ def log_model_inspection(model: nn.Module) -> None:
    from vllm.model_inspection import format_model_inspection

    logger.info("vLLM model structure:\n%s", format_model_inspection(model))


+def _has_online_quant(model: nn.Module):
+    for module in model.modules():
+        quant_method = getattr(module, "quant_method", None)
+        if getattr(quant_method, "uses_meta_device", False):
+            return True
+
+    return False
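uses_meta_device is a plain class attribute, so this scan works without any registry: any module whose quant method advertises the flag opts the whole model into the layerwise finalize pass. A compact illustration (OnlineQuantMethod and QuantLinear are hypothetical stand-ins):

import torch.nn as nn

class OnlineQuantMethod:
    uses_meta_device: bool = True  # class-level flag; no instance state needed

class QuantLinear(nn.Linear):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.quant_method = OnlineQuantMethod()

model = nn.Sequential(nn.Linear(8, 8), QuantLinear(8, 8))
has_online = any(
    getattr(getattr(m, "quant_method", None), "uses_meta_device", False)
    for m in model.modules()
)
assert has_online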
@@ -1,10 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import torch
import torch.nn as nn

from vllm.config import ModelConfig
from vllm.config.load import LoadConfig
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
+from vllm.model_executor.model_loader.reload.meta import materialize_meta_tensor
+from vllm.model_executor.model_loader.reload.utils import get_layer_tensors
from vllm.model_executor.model_loader.weight_utils import initialize_dummy_weights
@@ -23,6 +26,12 @@ class DummyModelLoader(BaseModelLoader):
        pass  # Nothing to download

    def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
+        # materialize meta tensors as part of online quantization lifecycle
+        for layer in model.modules():
+            for name, param in get_layer_tensors(layer).items():
+                if param.device == torch.device("meta"):
+                    setattr(layer, name, materialize_meta_tensor(param))
+
        # NOTE(woosuk): For accurate performance evaluation, we assign
        # random values to the weights.
        initialize_dummy_weights(model, model_config)
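For the dummy path the only extra work is swapping each meta tensor for real, uninitialized storage before the random fill. A self-contained sketch of that materialization step (a simplified stand-in for get_layer_tensors/materialize_meta_tensor):

import torch
import torch.nn as nn

def materialize(module: nn.Module, device: str = "cpu") -> None:
    # Re-register every meta parameter as a real, uninitialized tensor.
    for name, param in list(module.named_parameters(recurse=False)):
        if param.device == torch.device("meta"):
            real = nn.Parameter(
                torch.empty_like(param, device=device), requires_grad=False
            )
            setattr(module, name, real)

with torch.device("meta"):
    layer = nn.Linear(16, 16)
materialize(layer)
assert layer.weight.device == torch.device("cpu")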
@@ -21,12 +21,14 @@ Limitations:
__all__ = [
    "record_metadata_for_reloading",
    "initialize_layerwise_reload",
+    "finalize_layerwise_processing",
    "finalize_layerwise_reload",
    "set_torchao_reload_attrs",
    "support_quantized_model_reload_from_hp_weights",
]

from .layerwise import (
+    finalize_layerwise_processing,
    finalize_layerwise_reload,
    initialize_layerwise_reload,
    record_metadata_for_reloading,
@@ -28,6 +28,7 @@ __all__ = [
    "get_layerwise_info",
    "record_metadata_for_reloading",
    "initialize_layerwise_reload",
+    "finalize_layerwise_processing",
    "finalize_layerwise_reload",
]
@@ -89,7 +90,7 @@ def initialize_layerwise_reload(model: torch.nn.Module):
        info = get_layerwise_info(layer)

        # Skip if the layer has already been initialized
-        if info.can_process():
+        if info.can_load():
            continue

        # Save current tensors for later copying
@@ -98,15 +99,21 @@ def initialize_layerwise_reload(model: torch.nn.Module):
        # Restore layer parameters/buffers onto meta device
        restore_layer_on_meta(layer, info)

-        # Track loading progress to determine when to process/copy
-        info.load_numel = 0
-        info.load_numel_total = get_layer_size(layer)
+        initialize_online_processing(layer)

-        # Wrap each parameter's weight loader
-        # Note that nested wrapping will occur for shared tensors
-        for name, tensor in get_layer_tensors(layer).items():
-            if _get_weight_loader(tensor).__name__ != "online_process_loader":
-                tensor.weight_loader = make_online_process_loader(layer, name)


+def initialize_online_processing(layer: torch.nn.Module):
+    info = get_layerwise_info(layer)
+
+    # Track loading progress to determine when to process/copy
+    info.load_numel = 0
+    info.load_numel_total = get_layer_size(layer)
+
+    # Wrap each parameter's weight loader
+    # Note that nested wrapping will occur for shared tensors
+    for name, tensor in get_layer_tensors(layer).items():
+        if _get_weight_loader(tensor).__name__ != "online_process_loader":
+            tensor.weight_loader = make_online_process_loader(layer, name)


def make_online_process_loader(layer: torch.nn.Module, param_name: str) -> Callable:
@@ -118,7 +125,7 @@ def make_online_process_loader(layer: torch.nn.Module, param_name: str) -> Calla

    @wraps(original_loader, assigned=("__doc__", "__annotations__"))
    def online_process_loader(*args, **kwargs):
-        if not info.can_process():
+        if not info.can_load():
            # Unfortunately, some qconfigs are set up to load the same weight
            # multiple times. For example, CT_WNA16 loads `weight_shape` for
            # each of the qkv partitions. This results in layers loading extra
@@ -140,7 +147,7 @@ def make_online_process_loader(layer: torch.nn.Module, param_name: str) -> Calla
        bound_args = loader_signature.bind(*args, **kwargs)
        bound_args.apply_defaults()

-        # Cache loaded weights, track loading progress
+        # Buffer loaded weights, track loading progress
        info.loaded_weights.append((param_name, bound_args))
        num_loaded, ret = get_numel_loaded(original_loader, bound_args)
        info.load_numel += num_loaded
@@ -163,19 +170,26 @@ def make_online_process_loader(layer: torch.nn.Module, param_name: str) -> Calla
    return online_process_loader
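The wrapper relies on inspect.signature(...).bind so that a call buffered with any mix of positional and keyword arguments can be replayed later against a freshly materialized parameter. A stripped-down version of the same technique (no progress tracking; names hypothetical):

import inspect
from functools import wraps

def make_buffering_loader(original_loader, buffer):
    sig = inspect.signature(original_loader)

    @wraps(original_loader, assigned=("__doc__", "__annotations__"))
    def buffering_loader(*args, **kwargs):
        # Normalize the calling convention before buffering so the call
        # can be replayed later with a substituted `param`.
        bound = sig.bind(*args, **kwargs)
        bound.apply_defaults()
        buffer.append(bound)
        return original_loader(*bound.args, **bound.kwargs)

    return buffering_loader

def loader(param, weight, shard_id=None):
    return f"loaded {weight} into {param} (shard={shard_id})"

calls = []
wrapped = make_buffering_loader(loader, calls)
wrapped("w13", "chunk0", shard_id=1)
print(calls[0].arguments)  # {'param': 'w13', 'weight': 'chunk0', 'shard_id': 1}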


-def finalize_layerwise_reload(model: torch.nn.Module, model_config: ModelConfig):
+def finalize_layerwise_processing(model: torch.nn.Module, model_config: ModelConfig):
    """
-    Remove the outermost layer of weight loading wrappers.
+    Apply processing to any layers which were not layerwise processed during loading.
+    This includes attention layers and layers which have weight elements which are not
+    loaded (due to padding).

-    This function should be applied after `initialize_layerwise_reload` is applied
-    unwrap the layerwise weight loaders.
-
-    Also processes Attention/MLA layers, which must be processed after all other layers
+    :param model: model to finalize processing for
+    :param model_config: config needed for applying processing to attention layers
    """
-    model._do_torchao_reload = model._original_do_torchao_reload
+    if hasattr(model, "_original_do_torchao_reload"):
+        model._do_torchao_reload = model._original_do_torchao_reload

    for layer in model.modules():
        info = get_layerwise_info(layer)
        if not info.can_load():
            info.reset()
            continue

        # Attention/MLA layers are processed after all other layers
        if isinstance(layer, (Attention, MLAAttention)):
@@ -184,17 +198,29 @@ def finalize_layerwise_reload(model: torch.nn.Module, model_config: ModelConfig)
                    "Layerwise reloading of Q/K/V scale weights is not implemented yet"
                )

+            elif info.kernel_tensors is None:
+                raise NotImplementedError(
+                    "Layerwise loading of Q/K/V scale weights is not implemented yet"
+                )
+
+            else:
+                _place_kernel_tensors(layer, info)
+                layer.process_weights_after_loading(model_config.dtype)

-        # No weights were loaded, place kernel tensors back
-        elif info.can_process() and info.load_numel <= 0:
-            _place_kernel_tensors(layer, info)
+        # No weights were loaded
+        elif info.load_numel <= 0:
+            # first load but received no weights. This happens on dummy load
+            if info.kernel_tensors is None:
+                materialize_layer(layer)
+
+            # reloading: place kernel tensors back as a fallback
+            else:
+                logger.warning("%s: Failed to load weights", layer.__class__.__name__)
+                _place_kernel_tensors(layer, info)

        # Process non-attention layers which did not load all elements. This can happen
        # if the created weight has extra padding elements which are not loaded
-        # Having too many of these delayed layers can lead to execess memory usage
+        # Having too many of these delayed layers can lead to excess memory usage
        # see Limitations(4)
        elif info.load_numel > 0 and info.load_numel < info.load_numel_total:  # type: ignore[operator]
            logger.debug("%s: Delayed processing", layer.__class__.__name__)
@@ -203,20 +229,24 @@ def finalize_layerwise_reload(model: torch.nn.Module, model_config: ModelConfig)
        info.reset()


+def finalize_layerwise_reload(*args, **kwargs):
+    finalize_layerwise_processing(*args, **kwargs)
+
+
def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo):
    """
-    Finalize layer loading after all weights have been cached.
+    Finalize layer loading after all weights have been buffered.

    This function:
    1. Materializes the layer onto the target device
-    2. Loads all cached weights
+    2. Loads all buffered weights
    3. Runs quantization processing if applicable
    4. Copies processed values back to original tensor storage
    """
    # Materialize layer tensors onto device
    materialize_layer(layer)

-    # Reset FP8 online quantization flag so process_weights_after_loading
+    # Reset online quantization flag so process_weights_after_loading
    # will run again during reload
    if hasattr(layer, "_already_called_process_weights_after_loading"):
        delattr(layer, "_already_called_process_weights_after_loading")
@@ -225,7 +255,7 @@ def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo):
    for param in get_layer_tensors(layer).values():
        param.weight_loader = _get_original_loader(param)

-    # Load all cached weights into materialized layer (using original loaders)
+    # Load all buffered weights into materialized layer (using original loaders)
    for name, args in info.loaded_weights:
        param = getattr(layer, name)
        args.arguments["param"] = param
@@ -239,13 +269,14 @@ def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo):

    # Copy processed values into original tensor storage (preserves cudagraph refs)
-    # this code is a no-op if not reloading (because kernel tensors is empty)
-    parameters, buffers = info.kernel_tensors
-    for name, param in parameters.items():
-        param.data.copy_(getattr(layer, name))
-    for name, buffer in buffers.items():
-        buffer.data.copy_(getattr(layer, name))
+    if info.kernel_tensors is not None:
+        parameters, buffers = info.kernel_tensors
+        for name, param in parameters.items():
+            param.data.copy_(getattr(layer, name))
+        for name, buffer in buffers.items():
+            buffer.data.copy_(getattr(layer, name))

-    _place_kernel_tensors(layer, info)
+        _place_kernel_tensors(layer, info)

    info.reset()
    logger.debug("%s: Processed", layer.__class__.__name__)
@@ -268,6 +299,7 @@ def _place_kernel_tensors(layer: torch.nn.Module, info: LayerReloadingInfo):
    for name in get_layer_tensors(layer):
        delattr(layer, name)

+    assert info.kernel_tensors is not None
    parameters, buffers = info.kernel_tensors
    for name, param in parameters.items():
        layer.register_parameter(name, param)
@@ -104,7 +104,7 @@ def materialize_layer(layer: torch.nn.Module) -> None:
        setattr(layer, name, materialize_meta_tensor(tensor))


-class MetaCopyCounter(TorchDispatchMode):
+class CopyCounter(TorchDispatchMode):
    """
    Tracks total number of elements modified with `copy_`.
@@ -122,7 +122,7 @@ class MetaCopyCounter(TorchDispatchMode):
        if kwargs is None:
            kwargs = {}

-        if func is torch.ops.aten.copy_.default and args[0].device.type == "meta":
+        if func is torch.ops.aten.copy_.default:
            assert args[0].numel() == args[1].numel()
            self.copied_numel += args[0].numel()
@@ -140,7 +140,6 @@ def get_numel_loaded(
    :return: number of elements loaded by the weight loader, the return value of the
        weight loader
    """
-    assert args.arguments["param"].device.type == "meta"
-    with MetaCopyCounter() as counter:
+    with CopyCounter() as counter:
        return_value = weight_loader(*args.args, **args.kwargs)
    return counter.copied_numel, return_value
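With the meta-device restriction dropped, the counter is just a TorchDispatchMode that watches aten.copy_. A runnable sketch matching the behavior of the CopyCounter above (simplified relative to the real class):

import torch
from torch.utils._python_dispatch import TorchDispatchMode

class CopyNumelCounter(TorchDispatchMode):
    """Counts elements written via aten.copy_ while the mode is active."""

    def __init__(self):
        super().__init__()
        self.copied_numel = 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.ops.aten.copy_.default:
            self.copied_numel += args[0].numel()
        return func(*args, **kwargs)

dst = torch.empty(10)
with CopyNumelCounter() as counter:
    dst.copy_(torch.ones(10))
assert counter.copied_numel == 10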
@@ -16,8 +16,8 @@ class LayerReloadingInfo:
    # model format (meta), populated by `record_metadata_for_reloading`
    restore_metadata: LayerTensors = field(default_factory=lambda: ({}, {}))

-    # kernel format (device)
-    kernel_tensors: LayerTensors = field(default_factory=lambda: ({}, {}))
+    # kernel format (device), used to copy into when reloading only
+    kernel_tensors: LayerTensors | None = None

    # track how many restored elements are ready for loading
    load_numel: int = 0
@@ -29,5 +29,5 @@ class LayerReloadingInfo:
    def reset(self):
        self.__init__(restore_metadata=self.restore_metadata)  # type: ignore[misc]

-    def can_process(self) -> bool:
+    def can_load(self) -> bool:
        return self.load_numel_total is not None
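can_load() doubling as the "is online processing initialized" predicate makes load_numel_total the de facto state flag: None means untouched, a number means the layer is accumulating loads. A minimal dataclass mirroring that state machine (LoadProgress is a hypothetical stand-in for LayerReloadingInfo):

from dataclasses import dataclass, field

@dataclass
class LoadProgress:
    load_numel_total: int | None = None  # None until online processing is set up
    load_numel: int = 0
    loaded_weights: list = field(default_factory=list)

    def can_load(self) -> bool:
        return self.load_numel_total is not None

info = LoadProgress()
assert not info.can_load()      # layer not yet initialized for streaming loads
info.load_numel_total = 1024    # set by initialize_online_processing
assert info.can_load()
info.load_numel += 512          # half the elements have arrived
assert info.load_numel < info.load_numel_total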
@@ -1323,25 +1323,11 @@ def initialize_dummy_weights(
    is fixed, the random values generated by this function only depends on
    the parameter's number of elements and its data type.
    """

-    # Check if any module uses online quantization with meta device weights.
-    # If so, we'll skip initializing params on meta device since they'll be
-    # handled in `process_weights_after_loading`.
-    def uses_meta_device(module: torch.nn.Module) -> bool:
-        quant_method = getattr(module, "quant_method", None)
-        return getattr(quant_method, "uses_meta_device", False)
-
-    has_online_quant = any(uses_meta_device(m) for m in model.modules())
-
    for param in model.state_dict().values():
-        if has_online_quant and param.device == torch.device("meta"):
-            # For online quantization, weights are created on meta device and
-            # dummy weight init will happen in `process_weights_after_loading`.
-            continue
-
        initialize_single_dummy_weight(param, low, high, seed)
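One way to realize the reproducibility note above (values depending only on element count and dtype) is to drive a dedicated generator from a fixed seed rather than the global RNG state. A hedged sketch of that idea, not the actual implementation of initialize_single_dummy_weight:

import torch

@torch.no_grad()
def seeded_dummy_init(
    param: torch.Tensor, low: float = -1e-3, high: float = 1e-3, seed: int = 1234
) -> None:
    # A private generator keeps the fill independent of global RNG state.
    gen = torch.Generator(device=param.device)
    gen.manual_seed(seed)
    param.uniform_(low, high, generator=gen)

w = torch.empty(4, 4)
seeded_dummy_init(w)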


@torch.no_grad()
def initialize_single_dummy_weight(
    param: torch.Tensor,
    low: float = -1e-3,