fix memory for online fp8 quantization with streaming weight load (#31914)
Signed-off-by: vasiliy <vasiliy@fb.com>
This commit is contained in:
committed by
GitHub
parent
5d1aef3004
commit
0130223bd9
@@ -86,6 +86,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
cutlass_fp8_supported,
|
||||
normalize_e4m3fn_to_e4m3fnuz,
|
||||
)
|
||||
from vllm.model_executor.model_loader.weight_utils import initialize_single_dummy_weight
|
||||
from vllm.model_executor.parameter import (
|
||||
BlockQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
@@ -293,6 +294,16 @@ class CopyNumelCounter(TorchDispatchMode):
|
||||
return out
|
||||
|
||||
|
||||
def _copy_missing_attrs(old: torch.Tensor, new: torch.Tensor) -> None:
|
||||
"""Copies any attrs present in `old` but not in `new` to `new`"""
|
||||
new_attrs = set(dir(new))
|
||||
attrs_to_set = {}
|
||||
for attr in dir(old):
|
||||
if attr not in new_attrs:
|
||||
attrs_to_set[attr] = getattr(old, attr)
|
||||
set_weight_attrs(new, attrs_to_set)
|
||||
|
||||
|
||||
class Fp8LinearMethod(LinearMethodBase):
|
||||
"""Linear method for FP8.
|
||||
Supports loading FP8 checkpoints with static weight scale and
|
||||
@@ -578,6 +589,22 @@ class Fp8OnlineLinearMethod(Fp8LinearMethod):
|
||||
if not hasattr(layer, "_loaded_numel"):
|
||||
layer._loaded_numel = 0
|
||||
|
||||
# when the first `loaded_weight` is about to be
|
||||
# loaded to `param`, materialize `param` just-in-time
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty_like(layer.weight, device=layer._load_device),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=patched_weight_loader,
|
||||
)
|
||||
_copy_missing_attrs(layer.weight, weight)
|
||||
layer.register_parameter("weight", weight)
|
||||
del layer._load_device
|
||||
|
||||
# refresh the reference to `param` to reflect just-in-time
|
||||
# materialization
|
||||
param = layer.weight
|
||||
|
||||
# load the current weight chunk
|
||||
copy_numel_counter = CopyNumelCounter()
|
||||
with copy_numel_counter:
|
||||
@@ -590,30 +617,50 @@ class Fp8OnlineLinearMethod(Fp8LinearMethod):
|
||||
if layer._loaded_numel == target_loaded_numel:
|
||||
self.process_weights_after_loading(layer)
|
||||
|
||||
# Delete the bookkeeping
|
||||
del layer._loaded_numel
|
||||
# Prevent the usual `process_weights_after_loading` call from doing
|
||||
# anything
|
||||
layer._already_called_process_weights_after_loading = True
|
||||
|
||||
# Note that we keep `layer._loaded_numel` around just in case
|
||||
# there is logic added to vllm in the future which calls a
|
||||
# weight loader twice - we do not want to re-initialize in
|
||||
# that case.
|
||||
|
||||
return res
|
||||
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition,
|
||||
# materialized just-in-time in `patched_weight_loader`
|
||||
device="meta",
|
||||
dtype=params_dtype,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=patched_weight_loader,
|
||||
)
|
||||
# stash the correct device for `patched_weight_loader`
|
||||
layer._load_device = torch.get_default_device()
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
if getattr(layer, "_already_called_process_weights_after_loading", False):
|
||||
return
|
||||
|
||||
# deferred initialization of randomly initialized weights for the
|
||||
# `--load_format dummy` feature
|
||||
if layer.weight.device == torch.device("meta"):
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty_like(layer.weight, device=layer._load_device),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=layer.weight.weight_loader,
|
||||
)
|
||||
_copy_missing_attrs(layer.weight, weight)
|
||||
layer.register_parameter("weight", weight)
|
||||
initialize_single_dummy_weight(layer.weight)
|
||||
|
||||
# TODO(future): support block_quant in online quant path
|
||||
assert not self.block_quant
|
||||
|
||||
@@ -1069,6 +1116,39 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
|
||||
if not hasattr(layer, "_loaded_numel"):
|
||||
layer._loaded_numel = 0
|
||||
|
||||
# save the ids of original w13 and w2 so that we can
|
||||
# distinguish which one `param` should map to further
|
||||
# down in this file
|
||||
layer._w13_weight_orig_id = id(layer.w13_weight)
|
||||
layer._w2_weight_orig_id = id(layer.w2_weight)
|
||||
|
||||
# when the first `loaded_weight` is about to be
|
||||
# loaded to `param`, materialize `param` just-in-time
|
||||
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty_like(layer.w13_weight, device=layer._load_device),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
_copy_missing_attrs(layer.w13_weight, w13_weight)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
|
||||
w2_weight = torch.nn.Parameter(
|
||||
torch.empty_like(layer.w2_weight, device=layer._load_device),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
_copy_missing_attrs(layer.w2_weight, w2_weight)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
del layer._load_device
|
||||
|
||||
# refresh the reference to `param` to reflect just-in-time
|
||||
# materialization
|
||||
if id(param) == layer._w13_weight_orig_id:
|
||||
param = layer.w13_weight
|
||||
elif id(param) == layer._w2_weight_orig_id:
|
||||
param = layer.w2_weight
|
||||
|
||||
# load the current weight chunk
|
||||
copy_numel_counter = CopyNumelCounter()
|
||||
with copy_numel_counter:
|
||||
@@ -1081,12 +1161,16 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
|
||||
if layer._loaded_numel == target_loaded_numel:
|
||||
self.process_weights_after_loading(layer)
|
||||
|
||||
# Delete the bookkeeping
|
||||
del layer._loaded_numel
|
||||
# Prevent the usual `process_weights_after_loading` call
|
||||
# from doing anything
|
||||
layer._already_called_process_weights_after_loading = True
|
||||
|
||||
# Note that we keep `layer._loaded_numel`,
|
||||
# `layer._w13_weight_orig_id` and `layer._w2_weight_orig_id`
|
||||
# around because if EP is on, weight loaders for non-local
|
||||
# experts will run but not actually copy any elements, and we
|
||||
# need to not re-initialize in that case.
|
||||
|
||||
return res
|
||||
|
||||
new_extra_weight_attrs["weight_loader"] = patched_weight_loader
|
||||
@@ -1098,6 +1182,8 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size,
|
||||
# materialized just-in-time in `patched_weight_loader`
|
||||
device="meta",
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
@@ -1110,12 +1196,16 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition,
|
||||
# materialized just-in-time in `patched_weight_loader`
|
||||
device="meta",
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
# stash the correct device for `patched_weight_loader`
|
||||
layer._load_device = torch.get_default_device()
|
||||
|
||||
# WEIGHT_SCALES
|
||||
# Allocate 2 scales for w1 and w3 respectively.
|
||||
@@ -1138,6 +1228,31 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
|
||||
if getattr(layer, "_already_called_process_weights_after_loading", False):
|
||||
return
|
||||
|
||||
# deferred initialization of randomly initialized weights for the
|
||||
# `--load_format dummy` feature
|
||||
if layer.w13_weight.device == torch.device("meta"):
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty_like(layer.w13_weight, device=layer._load_device),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
w13_weight, {"weight_loader": layer.w13_weight.weight_loader}
|
||||
)
|
||||
_copy_missing_attrs(layer.w13_weight, w13_weight)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
initialize_single_dummy_weight(layer.w13_weight)
|
||||
if layer.w2_weight.device == torch.device("meta"):
|
||||
w2_weight = torch.nn.Parameter(
|
||||
torch.empty_like(layer.w2_weight, device=layer._load_device),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
w2_weight, {"weight_loader": layer.w2_weight.weight_loader}
|
||||
)
|
||||
_copy_missing_attrs(layer.w2_weight, w2_weight)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
initialize_single_dummy_weight(layer.w2_weight)
|
||||
|
||||
# If checkpoint is fp16, quantize in place.
|
||||
fp8_dtype = current_platform.fp8_dtype()
|
||||
w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
|
||||
|
||||
Reference in New Issue
Block a user