[V1] Fully Transparent Implementation of CPU Offloading (#15354)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -10,12 +10,14 @@ import torch.nn as nn
|
||||
from torch.func import functional_call
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.multimodal import MultiModalPlaceholderMap, NestedTensors
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import is_pin_memory_available
|
||||
from vllm.utils import (get_cuda_view_from_cpu_tensor, is_pin_memory_available,
|
||||
is_uva_available)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -505,6 +507,14 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
|
||||
return module
|
||||
|
||||
pin_memory = is_pin_memory_available()
|
||||
uva_available = is_uva_available()
|
||||
|
||||
if envs.VLLM_USE_V1:
|
||||
assert uva_available, ("V1 CPU offloading requires"
|
||||
" uva (pin memory) support")
|
||||
uva_offloading = True
|
||||
else:
|
||||
uva_offloading = False
|
||||
|
||||
# offload parameters to CPU
|
||||
# use pin_memory if possible, which helps cudagraph capture speed
|
||||
@@ -523,11 +533,16 @@ def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
|
||||
device='cpu',
|
||||
pin_memory=pin_memory)
|
||||
cpu_data.copy_(p.data)
|
||||
p.data = cpu_data
|
||||
if not uva_offloading:
|
||||
p.data = cpu_data
|
||||
else:
|
||||
# keep the cpu data alive
|
||||
p._vllm_offloaded_cpu_data = cpu_data
|
||||
p.data = get_cuda_view_from_cpu_tensor(cpu_data)
|
||||
_CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
|
||||
offloaded_parameters = True
|
||||
|
||||
if offloaded_parameters:
|
||||
if offloaded_parameters and not uva_offloading:
|
||||
original_forward = module.forward
|
||||
|
||||
def forward(*args, **kwargs):
|
||||
|
||||
Reference in New Issue
Block a user