fix memory for online fp8 quantization with streaming weight load (#31914)
Signed-off-by: vasiliy <vasiliy@fb.com>
This commit is contained in:
committed by
GitHub
parent
5d1aef3004
commit
0130223bd9
@@ -5,7 +5,10 @@
|
||||
Run `pytest tests/quantization/test_fp8.py --forked`.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
import regex as re
|
||||
import torch
|
||||
|
||||
from tests.quantization.utils import is_quant_method_supported
|
||||
@@ -195,6 +198,99 @@ def test_online_quantization(
|
||||
print(outputs[0][1])
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_quant_method_supported("fp8"),
|
||||
reason="FP8 is not supported on this GPU type.",
|
||||
)
|
||||
def test_online_quant_peak_mem(
|
||||
vllm_runner,
|
||||
caplog_mp_spawn,
|
||||
monkeypatch,
|
||||
) -> None:
|
||||
# Note: `allenai/OLMoE-1B-7B-0125-Instruct` was selected because:
|
||||
# 1. it covers both Linear and MoE paths
|
||||
# 2. it is already used by other tests in CI, so adding it here
|
||||
# does not increase disk space for CI runners
|
||||
# I really wanted to use `ibm-granite/granite-3.0-1b-a400m-base`
|
||||
# which I think is the smallest MoE model in vLLM (2.5 GiB bf16,
|
||||
# 1.3 GiB fp8), but could not as adding one more model makes CI
|
||||
# run out of disk space.
|
||||
model_name = "allenai/OLMoE-1B-7B-0125-Instruct"
|
||||
|
||||
# Force spawn to ensure caplog_mp_spawn works consistently
|
||||
# (it relies on VLLM_LOGGING_CONFIG_PATH which spawn reads but fork ignores)
|
||||
monkeypatch.setenv("VLLM_WORKER_MULTIPROC_METHOD", "spawn")
|
||||
|
||||
with (
|
||||
caplog_mp_spawn(logging.DEBUG) as log_holder,
|
||||
vllm_runner(
|
||||
model_name,
|
||||
quantization="fp8",
|
||||
enforce_eager=True,
|
||||
) as llm,
|
||||
):
|
||||
outputs = llm.generate_greedy(["The future of AI is"], max_tokens=4)
|
||||
print(outputs[0][1])
|
||||
|
||||
log_text = log_holder.text
|
||||
|
||||
# Parse memory usage from captured logs
|
||||
model_memory_gib = None
|
||||
peak_memory_gib = None
|
||||
for line in log_text.splitlines():
|
||||
if model_memory_gib is None:
|
||||
match = re.search(r"Model loading took ([\d.]+) GiB memory", line)
|
||||
if match:
|
||||
model_memory_gib = float(match.group(1))
|
||||
if peak_memory_gib is None:
|
||||
match = re.search(
|
||||
r"Peak GPU memory after loading weights: ([\d.]+) GiB", line
|
||||
)
|
||||
if match:
|
||||
peak_memory_gib = float(match.group(1))
|
||||
|
||||
assert model_memory_gib is not None, "Could not find model loading memory log"
|
||||
assert peak_memory_gib is not None, "Could not find peak memory log"
|
||||
print(f"GPU memory used after loading weights: {model_memory_gib} GiB")
|
||||
print(f"Peak GPU memory usage while loading weights: {peak_memory_gib} GiB")
|
||||
|
||||
# model specific, allenai/OLMoE-1B-7B-0125-Instruct fp8 online quant
|
||||
# uses 6.65 GiB for weight loading (bf16 checkpoint is ~12.89 GiB)
|
||||
expected_model_memory_gib = 6.7
|
||||
|
||||
# for allenai/OLMoE-1B-7B-0125-Instruct the number we see today is 9.06
|
||||
# GiB, which is 1.36x above model_memory_gib. A slightly higher number is
|
||||
# expected as when we load and quantize weights in a streaming fashion we
|
||||
# need to have individual weights in bf16 + fp8 alive at the same time.
|
||||
expected_peak_memory_gib = expected_model_memory_gib * 1.4
|
||||
|
||||
assert model_memory_gib < expected_model_memory_gib, (
|
||||
f"{model_memory_gib=} higher than {expected_model_memory_gib}"
|
||||
)
|
||||
assert peak_memory_gib < expected_peak_memory_gib, (
|
||||
f"{peak_memory_gib=} higher than {expected_peak_memory_gib}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_quant_method_supported("fp8"),
|
||||
reason="FP8 is not supported on this GPU type.",
|
||||
)
|
||||
def test_online_quant_load_format_dummy(
|
||||
vllm_runner,
|
||||
monkeypatch,
|
||||
caplog,
|
||||
) -> None:
|
||||
with vllm_runner(
|
||||
"ibm-granite/granite-3.0-1b-a400m-base",
|
||||
quantization="fp8",
|
||||
enforce_eager=True,
|
||||
load_format="dummy",
|
||||
) as llm:
|
||||
outputs = llm.generate_greedy(["The future of AI is"], max_tokens=4)
|
||||
print(outputs[0][1])
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not is_quant_method_supported("fp8"),
|
||||
reason="FP8 is not supported on this GPU type.",
|
||||
|
||||
@@ -86,6 +86,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
||||
cutlass_fp8_supported,
|
||||
normalize_e4m3fn_to_e4m3fnuz,
|
||||
)
|
||||
from vllm.model_executor.model_loader.weight_utils import initialize_single_dummy_weight
|
||||
from vllm.model_executor.parameter import (
|
||||
BlockQuantScaleParameter,
|
||||
ModelWeightParameter,
|
||||
@@ -293,6 +294,16 @@ class CopyNumelCounter(TorchDispatchMode):
|
||||
return out
|
||||
|
||||
|
||||
def _copy_missing_attrs(old: torch.Tensor, new: torch.Tensor) -> None:
|
||||
"""Copies any attrs present in `old` but not in `new` to `new`"""
|
||||
new_attrs = set(dir(new))
|
||||
attrs_to_set = {}
|
||||
for attr in dir(old):
|
||||
if attr not in new_attrs:
|
||||
attrs_to_set[attr] = getattr(old, attr)
|
||||
set_weight_attrs(new, attrs_to_set)
|
||||
|
||||
|
||||
class Fp8LinearMethod(LinearMethodBase):
|
||||
"""Linear method for FP8.
|
||||
Supports loading FP8 checkpoints with static weight scale and
|
||||
@@ -578,6 +589,22 @@ class Fp8OnlineLinearMethod(Fp8LinearMethod):
|
||||
if not hasattr(layer, "_loaded_numel"):
|
||||
layer._loaded_numel = 0
|
||||
|
||||
# when the first `loaded_weight` is about to be
|
||||
# loaded to `param`, materialize `param` just-in-time
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty_like(layer.weight, device=layer._load_device),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=patched_weight_loader,
|
||||
)
|
||||
_copy_missing_attrs(layer.weight, weight)
|
||||
layer.register_parameter("weight", weight)
|
||||
del layer._load_device
|
||||
|
||||
# refresh the reference to `param` to reflect just-in-time
|
||||
# materialization
|
||||
param = layer.weight
|
||||
|
||||
# load the current weight chunk
|
||||
copy_numel_counter = CopyNumelCounter()
|
||||
with copy_numel_counter:
|
||||
@@ -590,30 +617,50 @@ class Fp8OnlineLinearMethod(Fp8LinearMethod):
|
||||
if layer._loaded_numel == target_loaded_numel:
|
||||
self.process_weights_after_loading(layer)
|
||||
|
||||
# Delete the bookkeeping
|
||||
del layer._loaded_numel
|
||||
# Prevent the usual `process_weights_after_loading` call from doing
|
||||
# anything
|
||||
layer._already_called_process_weights_after_loading = True
|
||||
|
||||
# Note that we keep `layer._loaded_numel` around just in case
|
||||
# there is logic added to vllm in the future which calls a
|
||||
# weight loader twice - we do not want to re-initialize in
|
||||
# that case.
|
||||
|
||||
return res
|
||||
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty(
|
||||
output_size_per_partition,
|
||||
input_size_per_partition,
|
||||
# materialized just-in-time in `patched_weight_loader`
|
||||
device="meta",
|
||||
dtype=params_dtype,
|
||||
),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=patched_weight_loader,
|
||||
)
|
||||
# stash the correct device for `patched_weight_loader`
|
||||
layer._load_device = torch.get_default_device()
|
||||
layer.register_parameter("weight", weight)
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
if getattr(layer, "_already_called_process_weights_after_loading", False):
|
||||
return
|
||||
|
||||
# deferred initialization of randomly initialized weights for the
|
||||
# `--load_format dummy` feature
|
||||
if layer.weight.device == torch.device("meta"):
|
||||
weight = ModelWeightParameter(
|
||||
data=torch.empty_like(layer.weight, device=layer._load_device),
|
||||
input_dim=1,
|
||||
output_dim=0,
|
||||
weight_loader=layer.weight.weight_loader,
|
||||
)
|
||||
_copy_missing_attrs(layer.weight, weight)
|
||||
layer.register_parameter("weight", weight)
|
||||
initialize_single_dummy_weight(layer.weight)
|
||||
|
||||
# TODO(future): support block_quant in online quant path
|
||||
assert not self.block_quant
|
||||
|
||||
@@ -1069,6 +1116,39 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
|
||||
if not hasattr(layer, "_loaded_numel"):
|
||||
layer._loaded_numel = 0
|
||||
|
||||
# save the ids of original w13 and w2 so that we can
|
||||
# distinguish which one `param` should map to further
|
||||
# down in this file
|
||||
layer._w13_weight_orig_id = id(layer.w13_weight)
|
||||
layer._w2_weight_orig_id = id(layer.w2_weight)
|
||||
|
||||
# when the first `loaded_weight` is about to be
|
||||
# loaded to `param`, materialize `param` just-in-time
|
||||
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty_like(layer.w13_weight, device=layer._load_device),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(w13_weight, extra_weight_attrs)
|
||||
_copy_missing_attrs(layer.w13_weight, w13_weight)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
|
||||
w2_weight = torch.nn.Parameter(
|
||||
torch.empty_like(layer.w2_weight, device=layer._load_device),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
_copy_missing_attrs(layer.w2_weight, w2_weight)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
del layer._load_device
|
||||
|
||||
# refresh the reference to `param` to reflect just-in-time
|
||||
# materialization
|
||||
if id(param) == layer._w13_weight_orig_id:
|
||||
param = layer.w13_weight
|
||||
elif id(param) == layer._w2_weight_orig_id:
|
||||
param = layer.w2_weight
|
||||
|
||||
# load the current weight chunk
|
||||
copy_numel_counter = CopyNumelCounter()
|
||||
with copy_numel_counter:
|
||||
@@ -1081,12 +1161,16 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
|
||||
if layer._loaded_numel == target_loaded_numel:
|
||||
self.process_weights_after_loading(layer)
|
||||
|
||||
# Delete the bookkeeping
|
||||
del layer._loaded_numel
|
||||
# Prevent the usual `process_weights_after_loading` call
|
||||
# from doing anything
|
||||
layer._already_called_process_weights_after_loading = True
|
||||
|
||||
# Note that we keep `layer._loaded_numel`,
|
||||
# `layer._w13_weight_orig_id` and `layer._w2_weight_orig_id`
|
||||
# around because if EP is on, weight loaders for non-local
|
||||
# experts will run but not actually copy any elements, and we
|
||||
# need to not re-initialize in that case.
|
||||
|
||||
return res
|
||||
|
||||
new_extra_weight_attrs["weight_loader"] = patched_weight_loader
|
||||
@@ -1098,6 +1182,8 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
|
||||
num_experts,
|
||||
2 * intermediate_size_per_partition,
|
||||
hidden_size,
|
||||
# materialized just-in-time in `patched_weight_loader`
|
||||
device="meta",
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
@@ -1110,12 +1196,16 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
|
||||
num_experts,
|
||||
hidden_size,
|
||||
intermediate_size_per_partition,
|
||||
# materialized just-in-time in `patched_weight_loader`
|
||||
device="meta",
|
||||
dtype=params_dtype,
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
set_weight_attrs(w2_weight, extra_weight_attrs)
|
||||
# stash the correct device for `patched_weight_loader`
|
||||
layer._load_device = torch.get_default_device()
|
||||
|
||||
# WEIGHT_SCALES
|
||||
# Allocate 2 scales for w1 and w3 respectively.
|
||||
@@ -1138,6 +1228,31 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
|
||||
if getattr(layer, "_already_called_process_weights_after_loading", False):
|
||||
return
|
||||
|
||||
# deferred initialization of randomly initialized weights for the
|
||||
# `--load_format dummy` feature
|
||||
if layer.w13_weight.device == torch.device("meta"):
|
||||
w13_weight = torch.nn.Parameter(
|
||||
torch.empty_like(layer.w13_weight, device=layer._load_device),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
w13_weight, {"weight_loader": layer.w13_weight.weight_loader}
|
||||
)
|
||||
_copy_missing_attrs(layer.w13_weight, w13_weight)
|
||||
layer.register_parameter("w13_weight", w13_weight)
|
||||
initialize_single_dummy_weight(layer.w13_weight)
|
||||
if layer.w2_weight.device == torch.device("meta"):
|
||||
w2_weight = torch.nn.Parameter(
|
||||
torch.empty_like(layer.w2_weight, device=layer._load_device),
|
||||
requires_grad=False,
|
||||
)
|
||||
set_weight_attrs(
|
||||
w2_weight, {"weight_loader": layer.w2_weight.weight_loader}
|
||||
)
|
||||
_copy_missing_attrs(layer.w2_weight, w2_weight)
|
||||
layer.register_parameter("w2_weight", w2_weight)
|
||||
initialize_single_dummy_weight(layer.w2_weight)
|
||||
|
||||
# If checkpoint is fp16, quantize in place.
|
||||
fp8_dtype = current_platform.fp8_dtype()
|
||||
w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
|
||||
|
||||
@@ -13,6 +13,8 @@ from vllm.model_executor.model_loader.utils import (
|
||||
initialize_model,
|
||||
process_weights_after_loading,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.mem_utils import format_gib
|
||||
from vllm.utils.torch_utils import set_default_torch_dtype
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -56,6 +58,17 @@ class BaseModelLoader(ABC):
|
||||
logger.debug("Loading weights on %s ...", load_device)
|
||||
# Quantization does not happen in `load_weights` but after it
|
||||
self.load_weights(model, model_config)
|
||||
|
||||
# Log peak GPU memory after loading weights. This is needed
|
||||
# to have test coverage on peak memory for online quantization.
|
||||
if current_platform.is_cuda():
|
||||
peak_memory = torch.cuda.max_memory_allocated()
|
||||
logger.debug_once(
|
||||
"Peak GPU memory after loading weights: %s GiB",
|
||||
format_gib(peak_memory),
|
||||
scope="local",
|
||||
)
|
||||
|
||||
process_weights_after_loading(model, model_config, target_device)
|
||||
|
||||
return model.eval()
|
||||
|
||||
@@ -25,4 +25,4 @@ class DummyModelLoader(BaseModelLoader):
|
||||
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
|
||||
# NOTE(woosuk): For accurate performance evaluation, we assign
|
||||
# random values to the weights.
|
||||
initialize_dummy_weights(model)
|
||||
initialize_dummy_weights(model, model_config)
|
||||
|
||||
@@ -1059,6 +1059,7 @@ def composed_weight_loader(
|
||||
|
||||
def initialize_dummy_weights(
|
||||
model: torch.nn.Module,
|
||||
model_config: ModelConfig,
|
||||
low: float = -1e-3,
|
||||
high: float = 1e-3,
|
||||
seed: int = 1234,
|
||||
@@ -1075,41 +1076,61 @@ def initialize_dummy_weights(
|
||||
is fixed, the random values generated by this function only depends on
|
||||
the parameter's number of elements and its data type.
|
||||
"""
|
||||
for param in model.state_dict().values():
|
||||
if torch.is_floating_point(param):
|
||||
if current_platform.is_tpu():
|
||||
generator = torch.Generator(device="cpu")
|
||||
generator.manual_seed(seed)
|
||||
# Note: The param.uniform_ function cannot be used in this
|
||||
# context because it demands more TPU HBM than directly copying
|
||||
# from a CPU tensor.
|
||||
# Note: We avoid using torch.rank_like as it doesn't currently
|
||||
# support the generator argument.
|
||||
param.copy_(
|
||||
(high - low)
|
||||
* torch.rand(
|
||||
param.shape,
|
||||
generator=generator,
|
||||
dtype=param.dtype,
|
||||
layout=param.layout,
|
||||
requires_grad=param.requires_grad,
|
||||
device="cpu",
|
||||
)
|
||||
+ low
|
||||
)
|
||||
torch._sync(param)
|
||||
continue
|
||||
# TODO(future PR): make the check below more generic as more online
|
||||
# quant backends are added
|
||||
is_fp8_py_quant = model_config.quantization == "fp8"
|
||||
|
||||
generator = torch.Generator(device=param.data.device)
|
||||
for param in model.state_dict().values():
|
||||
if is_fp8_py_quant and param.device == torch.device("meta"):
|
||||
# for fp8.py's online quantization, dummy weight init will happen
|
||||
# in `process_weights_after_loading`.
|
||||
# TODO(future PR): consider refactoring dummy model init to compose
|
||||
# better with online quantization
|
||||
continue
|
||||
|
||||
initialize_single_dummy_weight(param, low, high, seed)
|
||||
|
||||
|
||||
def initialize_single_dummy_weight(
|
||||
param: torch.Tensor,
|
||||
low: float = -1e-3,
|
||||
high: float = 1e-3,
|
||||
seed: int = 1234,
|
||||
) -> None:
|
||||
if torch.is_floating_point(param):
|
||||
if current_platform.is_tpu():
|
||||
generator = torch.Generator(device="cpu")
|
||||
generator.manual_seed(seed)
|
||||
if torch.finfo(param.data.dtype).bits < 16:
|
||||
# uniform_ doesn't support < 16-bit datatypes (FP8)
|
||||
dtype = param.data.dtype
|
||||
tmp_param = param.data.to(torch.float16)
|
||||
tmp_param = tmp_param.uniform_(low, high, generator=generator).to(dtype)
|
||||
param.data.copy_(tmp_param)
|
||||
else:
|
||||
param.uniform_(low, high, generator=generator)
|
||||
# Note: The param.uniform_ function cannot be used in this
|
||||
# context because it demands more TPU HBM than directly copying
|
||||
# from a CPU tensor.
|
||||
# Note: We avoid using torch.rank_like as it doesn't currently
|
||||
# support the generator argument.
|
||||
param.copy_(
|
||||
(high - low)
|
||||
* torch.rand(
|
||||
param.shape,
|
||||
generator=generator,
|
||||
dtype=param.dtype,
|
||||
layout=param.layout,
|
||||
requires_grad=param.requires_grad,
|
||||
device="cpu",
|
||||
)
|
||||
+ low
|
||||
)
|
||||
torch._sync(param)
|
||||
return
|
||||
|
||||
generator = torch.Generator(device=param.data.device)
|
||||
generator.manual_seed(seed)
|
||||
if torch.finfo(param.data.dtype).bits < 16:
|
||||
# uniform_ doesn't support < 16-bit datatypes (FP8)
|
||||
dtype = param.data.dtype
|
||||
tmp_param = param.data.to(torch.float16)
|
||||
tmp_param = tmp_param.uniform_(low, high, generator=generator).to(dtype)
|
||||
param.data.copy_(tmp_param)
|
||||
else:
|
||||
param.uniform_(low, high, generator=generator)
|
||||
|
||||
|
||||
def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> str | None:
|
||||
|
||||
Reference in New Issue
Block a user