fix memory for online fp8 quantization with streaming weight load (#31914)

Signed-off-by: vasiliy <vasiliy@fb.com>
This commit is contained in:
Vasiliy Kuznetsov
2026-02-02 14:17:42 -05:00
committed by GitHub
parent 5d1aef3004
commit 0130223bd9
5 changed files with 283 additions and 38 deletions

View File

@@ -86,6 +86,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_fp8_supported,
normalize_e4m3fn_to_e4m3fnuz,
)
from vllm.model_executor.model_loader.weight_utils import initialize_single_dummy_weight
from vllm.model_executor.parameter import (
BlockQuantScaleParameter,
ModelWeightParameter,
@@ -293,6 +294,16 @@ class CopyNumelCounter(TorchDispatchMode):
return out
def _copy_missing_attrs(old: torch.Tensor, new: torch.Tensor) -> None:
"""Copies any attrs present in `old` but not in `new` to `new`"""
new_attrs = set(dir(new))
attrs_to_set = {}
for attr in dir(old):
if attr not in new_attrs:
attrs_to_set[attr] = getattr(old, attr)
set_weight_attrs(new, attrs_to_set)
class Fp8LinearMethod(LinearMethodBase):
"""Linear method for FP8.
Supports loading FP8 checkpoints with static weight scale and
@@ -578,6 +589,22 @@ class Fp8OnlineLinearMethod(Fp8LinearMethod):
if not hasattr(layer, "_loaded_numel"):
layer._loaded_numel = 0
# when the first `loaded_weight` is about to be
# loaded to `param`, materialize `param` just-in-time
weight = ModelWeightParameter(
data=torch.empty_like(layer.weight, device=layer._load_device),
input_dim=1,
output_dim=0,
weight_loader=patched_weight_loader,
)
_copy_missing_attrs(layer.weight, weight)
layer.register_parameter("weight", weight)
del layer._load_device
# refresh the reference to `param` to reflect just-in-time
# materialization
param = layer.weight
# load the current weight chunk
copy_numel_counter = CopyNumelCounter()
with copy_numel_counter:
@@ -590,30 +617,50 @@ class Fp8OnlineLinearMethod(Fp8LinearMethod):
if layer._loaded_numel == target_loaded_numel:
self.process_weights_after_loading(layer)
# Delete the bookkeeping
del layer._loaded_numel
# Prevent the usual `process_weights_after_loading` call from doing
# anything
layer._already_called_process_weights_after_loading = True
# Note that we keep `layer._loaded_numel` around just in case
# there is logic added to vllm in the future which calls a
# weight loader twice - we do not want to re-initialize in
# that case.
return res
weight = ModelWeightParameter(
data=torch.empty(
output_size_per_partition,
input_size_per_partition,
# materialized just-in-time in `patched_weight_loader`
device="meta",
dtype=params_dtype,
),
input_dim=1,
output_dim=0,
weight_loader=patched_weight_loader,
)
# stash the correct device for `patched_weight_loader`
layer._load_device = torch.get_default_device()
layer.register_parameter("weight", weight)
def process_weights_after_loading(self, layer: Module) -> None:
if getattr(layer, "_already_called_process_weights_after_loading", False):
return
# deferred initialization of randomly initialized weights for the
# `--load_format dummy` feature
if layer.weight.device == torch.device("meta"):
weight = ModelWeightParameter(
data=torch.empty_like(layer.weight, device=layer._load_device),
input_dim=1,
output_dim=0,
weight_loader=layer.weight.weight_loader,
)
_copy_missing_attrs(layer.weight, weight)
layer.register_parameter("weight", weight)
initialize_single_dummy_weight(layer.weight)
# TODO(future): support block_quant in online quant path
assert not self.block_quant
@@ -1069,6 +1116,39 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
if not hasattr(layer, "_loaded_numel"):
layer._loaded_numel = 0
# save the ids of original w13 and w2 so that we can
# distinguish which one `param` should map to further
# down in this file
layer._w13_weight_orig_id = id(layer.w13_weight)
layer._w2_weight_orig_id = id(layer.w2_weight)
# when the first `loaded_weight` is about to be
# loaded to `param`, materialize `param` just-in-time
w13_weight = torch.nn.Parameter(
torch.empty_like(layer.w13_weight, device=layer._load_device),
requires_grad=False,
)
set_weight_attrs(w13_weight, extra_weight_attrs)
_copy_missing_attrs(layer.w13_weight, w13_weight)
layer.register_parameter("w13_weight", w13_weight)
w2_weight = torch.nn.Parameter(
torch.empty_like(layer.w2_weight, device=layer._load_device),
requires_grad=False,
)
set_weight_attrs(w2_weight, extra_weight_attrs)
_copy_missing_attrs(layer.w2_weight, w2_weight)
layer.register_parameter("w2_weight", w2_weight)
del layer._load_device
# refresh the reference to `param` to reflect just-in-time
# materialization
if id(param) == layer._w13_weight_orig_id:
param = layer.w13_weight
elif id(param) == layer._w2_weight_orig_id:
param = layer.w2_weight
# load the current weight chunk
copy_numel_counter = CopyNumelCounter()
with copy_numel_counter:
@@ -1081,12 +1161,16 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
if layer._loaded_numel == target_loaded_numel:
self.process_weights_after_loading(layer)
# Delete the bookkeeping
del layer._loaded_numel
# Prevent the usual `process_weights_after_loading` call
# from doing anything
layer._already_called_process_weights_after_loading = True
# Note that we keep `layer._loaded_numel`,
# `layer._w13_weight_orig_id` and `layer._w2_weight_orig_id`
# around because if EP is on, weight loaders for non-local
# experts will run but not actually copy any elements, and we
# need to not re-initialize in that case.
return res
new_extra_weight_attrs["weight_loader"] = patched_weight_loader
@@ -1098,6 +1182,8 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
# materialized just-in-time in `patched_weight_loader`
device="meta",
dtype=params_dtype,
),
requires_grad=False,
@@ -1110,12 +1196,16 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
num_experts,
hidden_size,
intermediate_size_per_partition,
# materialized just-in-time in `patched_weight_loader`
device="meta",
dtype=params_dtype,
),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# stash the correct device for `patched_weight_loader`
layer._load_device = torch.get_default_device()
# WEIGHT_SCALES
# Allocate 2 scales for w1 and w3 respectively.
@@ -1138,6 +1228,31 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
if getattr(layer, "_already_called_process_weights_after_loading", False):
return
# deferred initialization of randomly initialized weights for the
# `--load_format dummy` feature
if layer.w13_weight.device == torch.device("meta"):
w13_weight = torch.nn.Parameter(
torch.empty_like(layer.w13_weight, device=layer._load_device),
requires_grad=False,
)
set_weight_attrs(
w13_weight, {"weight_loader": layer.w13_weight.weight_loader}
)
_copy_missing_attrs(layer.w13_weight, w13_weight)
layer.register_parameter("w13_weight", w13_weight)
initialize_single_dummy_weight(layer.w13_weight)
if layer.w2_weight.device == torch.device("meta"):
w2_weight = torch.nn.Parameter(
torch.empty_like(layer.w2_weight, device=layer._load_device),
requires_grad=False,
)
set_weight_attrs(
w2_weight, {"weight_loader": layer.w2_weight.weight_loader}
)
_copy_missing_attrs(layer.w2_weight, w2_weight)
layer.register_parameter("w2_weight", w2_weight)
initialize_single_dummy_weight(layer.w2_weight)
# If checkpoint is fp16, quantize in place.
fp8_dtype = current_platform.fp8_dtype()
w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)