[Feature] Support CPU offloading without PyTorch pinned memory, which leads to doubled allocation (#32993)

Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
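Background: `Tensor.pin_memory()` in PyTorch does not pin an allocation in place; it returns a second, page-locked copy of the tensor, so while weights are being offloaded to the CPU both the pageable and the pinned buffer are briefly resident, roughly doubling peak host memory. This commit adds the `VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY` environment variable so users can skip the pinning step and trade transfer speed for peak host memory. A minimal sketch of the underlying PyTorch behavior (plain PyTorch, no vLLM APIs):

    import torch

    weight = torch.randn(4096, 4096)  # pageable host allocation #1
    pinned = weight.pin_memory()      # page-locked allocation #2, plus a copy
    # Until `weight` is released, both buffers are alive: peak host usage is
    # roughly 2x the tensor size. Skipping the pinning step keeps a single
    # pageable copy, at the cost of slower host-to-device transfers.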
@@ -11,6 +11,7 @@ import torch
 from torch import nn
 from typing_extensions import assert_never

+import vllm.envs as envs
 from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
 from vllm.logger import init_logger
 from vllm.model_executor.layers.attention import Attention, MLAAttention
@@ -25,6 +26,7 @@ from vllm.model_executor.model_loader.reload import (
 from vllm.model_executor.models.interfaces import SupportsQuant
 from vllm.tracing import instrument
 from vllm.utils.platform_utils import is_pin_memory_available
+from vllm.utils.torch_utils import get_accelerator_view_from_cpu_tensor

 logger = init_logger(__name__)

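The two import hunks wire in what the new path needs: `vllm.envs` exposes the new `VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY` flag, and `get_accelerator_view_from_cpu_tensor` builds a device-addressable (UVA) view over a CPU tensor. A hedged sketch of opting out of pinning; the flag's exact parsing and default live in `vllm.envs`, and reading it this way is an assumption:

    import os

    # Set before vLLM reads its environment flags (parsing is defined
    # in vllm.envs; treating "1" as truthy is an assumption here).
    os.environ["VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY"] = "1"

    import vllm.envs as envs

    if envs.VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY:
        print("weight offloading will use pageable (unpinned) CPU memory")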
@@ -111,7 +113,8 @@ def process_weights_after_loading(
         ):
             # TODO(lucas): see if there is a way to unify the signatures
             # of process_weights_after_loading
-            module.process_weights_after_loading(model_config.dtype)
+            with device_loading_context(module, target_device):
+                module.process_weights_after_loading(model_config.dtype)

 # Needed for torchao model reloading via model.reload_weights
 # @kylesayrs @jerryzh168 this can be removed if callers move to `reload_weights`
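The hunk above is the behavioral change in `process_weights_after_loading`: attention modules that post-process their weights now do so inside `device_loading_context`, so parameters offloaded to the CPU are temporarily materialized on the target device and restored afterwards. A simplified, hypothetical analogue of that pattern (the real context manager, shown in the next hunk, additionally re-offloads UVA parameters):

    from contextlib import contextmanager

    import torch

    @contextmanager
    def _params_on_device(module: torch.nn.Module, device: torch.device):
        # Hypothetical simplification of device_loading_context: hoist
        # CPU-resident parameters onto `device`, yield, then move them back.
        moved: list[str] = []
        for name, p in module.named_parameters():
            if p.device.type == "cpu":
                moved.append(name)
                p.data = p.data.to(device)
        try:
            yield module
        finally:
            for name, p in module.named_parameters():
                if name in moved:
                    p.data = p.data.to("cpu")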
@@ -127,38 +130,41 @@ def device_loading_context(module: torch.nn.Module, target_device: torch.device)
         return

     original_device_states: dict[str, torch.device] = {}
+    uva_offloaded_parameters: list[str] = []

     # Store original device states and move parameters to GPU if they're on CPU
     for name, p in module.named_parameters():
         if p.device.type == "cpu":
             original_device_states[name] = p.device
             p.data = p.data.to(target_device)
+            if getattr(p, "_vllm_is_uva_offloaded", False):
+                uva_offloaded_parameters.append(name)
         # Parameters already on target device are not touched

     try:
         yield module

     finally:
+        use_pin_memory = (
+            is_pin_memory_available()
+            and not envs.VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY
+        )
         # Restore parameters to their original devices, ignoring new parameters
-        pin_memory = is_pin_memory_available()
         for name, p in module.named_parameters():
             if name in original_device_states:
                 original_device: torch.device = original_device_states[name]
-                if original_device.type == "cpu":
-                    # `torch.empty_like` does not support `pin_memory` argument
-                    cpu_data = torch.empty_strided(
-                        size=p.data.size(),
-                        stride=p.data.stride(),
-                        dtype=p.data.dtype,
-                        layout=p.data.layout,
-                        device="cpu",
-                        pin_memory=pin_memory,
-                    )
-                    cpu_data.copy_(p.data)
-                    p.data = cpu_data
-                else:
-                    p.data = p.data.to(original_device)
+                p.data = p.data.to(original_device)
+
+                # parameter is UVA offloaded, but was replaced with a new device tensor
+                # re-offload it to CPU using UVA
+                if name in uva_offloaded_parameters and not getattr(
+                    p, "_vllm_is_uva_offloaded", False
+                ):
+                    cpu_data = p.data.to(device="cpu")
+                    if use_pin_memory:
+                        cpu_data = cpu_data.pin_memory()
+                    p.data = get_accelerator_view_from_cpu_tensor(cpu_data)
+                    p._vllm_is_uva_offloaded = True
             # New parameters or parameters already on target device are untouched


 _MODEL_ARCH_BY_HASH = dict[int, tuple[type[nn.Module], str]]()
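The restore path in the final hunk is where the feature lands. Previously the context manager always rebuilt a pinned CPU buffer with `torch.empty_strided(..., pin_memory=...)`; now parameters are simply moved back to their recorded device, and any parameter that was UVA-offloaded before processing is re-offloaded, pinning only when `use_pin_memory` allows it. A hedged round-trip sketch of that re-offload step, assuming a CUDA build where `get_accelerator_view_from_cpu_tensor` accepts unpinned input (as the diff's own code path implies):

    import torch

    from vllm.utils.torch_utils import get_accelerator_view_from_cpu_tensor

    use_pin_memory = False  # e.g. VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY=1

    # Re-offload a weight that currently lives on the accelerator.
    device_weight = torch.randn(1024, 1024, device="cuda")
    cpu_data = device_weight.to(device="cpu")
    if use_pin_memory:
        cpu_data = cpu_data.pin_memory()  # the extra allocation being avoided
    # UVA view: a device-addressable tensor backed by the CPU allocation,
    # with no separate device copy.
    device_view = get_accelerator_view_from_cpu_tensor(cpu_data)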