[Feature] Support CPU offloading without PyTorch pinned memory, which otherwise doubles host allocation (#32993)

Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Wei Zhao authored on 2026-02-13 11:11:26 -05:00, committed by GitHub
parent 4a9952ec1b
commit 59d53066d8
6 changed files with 127 additions and 62 deletions
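
For context on the title: `Tensor.pin_memory()` in PyTorch does not pin a buffer in place; it allocates a second, page-locked buffer and copies the data into it, so offloading weights through pinned memory transiently holds two host copies of every tensor. A minimal sketch of that behavior, assuming a CUDA-enabled PyTorch build (pin_memory requires one):

    import torch

    t = torch.empty(1024, 1024)      # pageable CPU allocation
    pinned = t.pin_memory()          # NEW page-locked allocation, i.e. a second copy
    assert pinned.data_ptr() != t.data_ptr()
    assert pinned.is_pinned() and not t.is_pinned()
    # Until `t` is released, both buffers are alive, doubling host memory for
    # this tensor. The new VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY flag lets
    # users skip the pin_memory() copy at the cost of slower transfers.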


@@ -11,6 +11,7 @@ import torch
 from torch import nn
 from typing_extensions import assert_never
 
+import vllm.envs as envs
 from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
 from vllm.logger import init_logger
 from vllm.model_executor.layers.attention import Attention, MLAAttention
@@ -25,6 +26,7 @@ from vllm.model_executor.model_loader.reload import (
 from vllm.model_executor.models.interfaces import SupportsQuant
 from vllm.tracing import instrument
 from vllm.utils.platform_utils import is_pin_memory_available
+from vllm.utils.torch_utils import get_accelerator_view_from_cpu_tensor
 
 logger = init_logger(__name__)
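
The newly imported `vllm.envs` module exposes the flag checked in the restore path below. A short sketch of how a user would toggle it, assuming the flag follows the usual vllm.envs pattern of lazy, per-access evaluation from the process environment:

    import os

    os.environ["VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY"] = "1"

    import vllm.envs as envs
    # vllm.envs resolves flags on attribute access, so this reflects the
    # value set above (assumption: this flag is a standard boolean env var).
    assert envs.VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY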
@@ -111,7 +113,8 @@ def process_weights_after_loading(
         ):
             # TODO(lucas): see if there is a way to unify the signatures
             # of process_weights_after_loading
-            module.process_weights_after_loading(model_config.dtype)
+            with device_loading_context(module, target_device):
+                module.process_weights_after_loading(model_config.dtype)
 
 # Needed for torchao model reloading via model.reload_weights
 # @kylesayrs @jerryzh168 this can be removed if callers move to `reload_weights`
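
How a caller uses the context manager changed below, as a hedged sketch: `module` and `model_config` are assumed to be in scope (an attention layer exposing a `process_weights_after_loading(dtype)` hook and a loaded `ModelConfig`, respectively):

    import torch

    target_device = torch.device("cuda:0")  # assumed accelerator
    with device_loading_context(module, target_device):
        # CPU-resident parameters are temporarily on the accelerator here,
        # which the weight-processing hooks require.
        module.process_weights_after_loading(model_config.dtype)
    # On exit, parameters return to their original devices; after this PR,
    # UVA-offloaded parameters are re-offloaded to CPU rather than left
    # behind as plain device tensors.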
@@ -127,38 +130,41 @@ def device_loading_context(module: torch.nn.Module, target_device: torch.device)
         return
 
     original_device_states: dict[str, torch.device] = {}
+    uva_offloaded_parameters: list[str] = []
 
     # Store original device states and move parameters to GPU if they're on CPU
     for name, p in module.named_parameters():
         if p.device.type == "cpu":
             original_device_states[name] = p.device
             p.data = p.data.to(target_device)
+            if getattr(p, "_vllm_is_uva_offloaded", False):
+                uva_offloaded_parameters.append(name)
         # Parameters already on target device are not touched
 
     try:
         yield module
     finally:
+        use_pin_memory = (
+            is_pin_memory_available()
+            and not envs.VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY
+        )
         # Restore parameters to their original devices, ignoring new parameters
-        pin_memory = is_pin_memory_available()
         for name, p in module.named_parameters():
             if name in original_device_states:
                 original_device: torch.device = original_device_states[name]
-                if original_device.type == "cpu":
-                    # `torch.empty_like` does not support `pin_memory` argument
-                    cpu_data = torch.empty_strided(
-                        size=p.data.size(),
-                        stride=p.data.stride(),
-                        dtype=p.data.dtype,
-                        layout=p.data.layout,
-                        device="cpu",
-                        pin_memory=pin_memory,
-                    )
-                    cpu_data.copy_(p.data)
-                    p.data = cpu_data
-                else:
-                    p.data = p.data.to(original_device)
+                p.data = p.data.to(original_device)
+                # parameter is UVA offloaded, but was replaced with a new device tensor
+                # re-offload it to CPU using UVA
+                if name in uva_offloaded_parameters and not getattr(
+                    p, "_vllm_is_uva_offloaded", False
+                ):
+                    cpu_data = p.data.to(device="cpu")
+                    if use_pin_memory:
+                        cpu_data = cpu_data.pin_memory()
+                    p.data = get_accelerator_view_from_cpu_tensor(cpu_data)
+                    p._vllm_is_uva_offloaded = True
             # New parameters or parameters already on target device are untouched
 
 
 _MODEL_ARCH_BY_HASH = dict[int, tuple[type[nn.Module], str]]()
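
Pulling the new restore logic together: a standalone sketch of the re-offload step, not part of this commit. `reoffload_parameter` is a hypothetical helper name; the marker attribute `_vllm_is_uva_offloaded` and both imported helpers are taken from the diff above, and the zero-copy behavior of `get_accelerator_view_from_cpu_tensor` is an assumption based on how it is used here:

    import torch

    import vllm.envs as envs
    from vllm.utils.platform_utils import is_pin_memory_available
    from vllm.utils.torch_utils import get_accelerator_view_from_cpu_tensor

    def reoffload_parameter(p: torch.nn.Parameter) -> None:
        """Re-offload a parameter to CPU while keeping a device-visible view."""
        use_pin_memory = (
            is_pin_memory_available()
            and not envs.VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY
        )
        cpu_data = p.data.to(device="cpu")
        if use_pin_memory:
            # pin_memory() allocates a second, page-locked copy; disabling
            # it via the env flag avoids the doubled host allocation.
            cpu_data = cpu_data.pin_memory()
        # The bytes stay in host memory; the module is handed an accelerator
        # view of them, so no separate device allocation is made.
        p.data = get_accelerator_view_from_cpu_tensor(cpu_data)
        p._vllm_is_uva_offloaded = True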