[Feature] Support CPU Offloading without PyTorch Pinned Memory that leads to doubled allocation (#32993)

Signed-off-by: wzhao18 <wzhao18.sz@gmail.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
Wei Zhao
2026-02-13 11:11:26 -05:00
committed by GitHub
parent 4a9952ec1b
commit 59d53066d8
6 changed files with 127 additions and 62 deletions

View File

@@ -13,6 +13,7 @@ from torch.func import functional_call
from torch.nn.modules.module import register_module_module_registration_hook
from transformers import PretrainedConfig
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed import (
get_tensor_model_parallel_rank,
@@ -633,11 +634,10 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
return module
pin_memory = is_pin_memory_available()
uva_available = is_uva_available()
assert uva_available, "V1 CPU offloading requires uva (pin memory) support"
uva_offloading = True
pin_memory = (
is_pin_memory_available() and not envs.VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY
)
uva_offloading = is_uva_available() and not envs.VLLM_WEIGHT_OFFLOADING_DISABLE_UVA
# offload parameters to CPU
# use pin_memory if possible, which helps cudagraph capture speed
@@ -648,22 +648,16 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
# one module might have some parameters offloaded and some not
break
# `torch.empty_like` does not support `pin_memory` argument
cpu_data = torch.empty_strided(
size=p.data.size(),
stride=p.data.stride(),
dtype=p.data.dtype,
layout=p.data.layout,
device="cpu",
pin_memory=pin_memory,
)
cpu_data.copy_(p.data)
cpu_data = p.data.to(device="cpu")
if pin_memory:
cpu_data = cpu_data.pin_memory()
if not uva_offloading:
p.data = cpu_data
else:
# keep the cpu data alive
p._vllm_offloaded_cpu_data = cpu_data
p.data = get_accelerator_view_from_cpu_tensor(cpu_data)
p._vllm_is_uva_offloaded = True
_CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
offloaded_parameters = True
@@ -678,7 +672,12 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
k: v.to(device, non_blocking=True)
for k, v in module.state_dict().items()
}
output = functional_call(module, device_state, args=args, kwargs=kwargs)
# set `tie_weights=False` as tied weights in original model
# become untied when calling .to(device) individually
output = functional_call(
module, device_state, args=args, kwargs=kwargs, tie_weights=False
)
module.forward = forward
return output