[core][model] yet another cpu offload implementation (#6496)
Co-authored-by: Michael Goin <michael@neuralmagic.com>
This commit is contained in:
@@ -1,8 +1,10 @@
|
||||
from typing import Callable, Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
from torch.func import functional_call
|
||||
|
||||
from vllm.multimodal import BatchedTensors
|
||||
from vllm.utils import is_pin_memory_available
|
||||
|
||||
|
||||
def merge_vision_embeddings(input_ids: torch.Tensor,
|
||||
@@ -52,6 +54,70 @@ class PPMissingLayer(torch.nn.Identity):
|
||||
super().__init__()
|
||||
|
||||
|
||||
# Running total of parameter bytes that have been moved to CPU so far.
_CPU_OFFLOAD_BYTES = 0
# Upper bound on `_CPU_OFFLOAD_BYTES`; 0 means nothing will be offloaded.
_CPU_OFFLOAD_MAX_BYTES = 0


def set_cpu_offload_max_bytes(max_bytes: int) -> None:
    """Install a new CPU offload budget and restart the byte accounting.

    Args:
        max_bytes: New value for the module-level offload limit, in bytes.

    Side effects:
        Rebinds both module-level counters: the limit is set to
        ``max_bytes`` and the bytes-offloaded-so-far counter is reset to 0.
    """
    global _CPU_OFFLOAD_BYTES, _CPU_OFFLOAD_MAX_BYTES
    _CPU_OFFLOAD_BYTES, _CPU_OFFLOAD_MAX_BYTES = 0, max_bytes
|
||||
|
||||
|
||||
def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
    """Move a module's parameters to CPU while the global byte budget allows.

    Parameters are offloaded one at a time until the module-level counter
    `_CPU_OFFLOAD_BYTES` reaches `_CPU_OFFLOAD_MAX_BYTES`, so a module may end
    up with only some of its parameters on CPU. The module's `forward` is
    then replaced by a wrapper that copies the state dict back to the
    original device for each call and runs the real forward through
    `torch.func.functional_call`.

    Args:
        module: The module to (possibly) offload. It is modified in place.

    Returns:
        The same `module` object. It is returned unchanged when it has no
        parameters, when its first parameter already lives on CPU, or when
        the offload budget is already exhausted.
    """
    global _CPU_OFFLOAD_BYTES

    # A module without parameters has nothing to offload; using a default
    # here avoids the StopIteration that a bare `next()` would raise.
    first_param = next(module.parameters(), None)
    if first_param is None:
        return module

    device = first_param.device

    if device == torch.device("cpu"):
        return module

    if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
        return module

    pin_memory = is_pin_memory_available()

    # offload parameters to CPU
    # use pin_memory if possible, which helps cudagraph capture speed
    for p in module.parameters():
        if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
            # we use per-parameter offloading
            # one module might have some parameters offloaded and some not
            break

        # `torch.empty_like` does not support `pin_memory` argument
        cpu_data = torch.empty(size=p.data.size(),
                               dtype=p.data.dtype,
                               layout=p.data.layout,
                               device='cpu',
                               pin_memory=pin_memory)
        cpu_data.copy_(p.data)
        p.data = cpu_data
        _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()

    # Snapshot taken once, after offloading; the wrapper below reuses it on
    # every call.
    state_dict: Dict[str, torch.Tensor] = module.state_dict()

    original_forward = module.forward

    def forward(*args, **kwargs):
        # Temporarily put the real forward back so that `functional_call`
        # runs it instead of re-entering this wrapper.
        module.forward = original_forward
        try:
            device_state = {
                # here we blindly call `to(device)`
                # if the parameter is already on the device, it will be a no-op
                k: v.to(device, non_blocking=True)
                for k, v in state_dict.items()
            }
            return functional_call(module,
                                   device_state,
                                   args=args,
                                   kwargs=kwargs)
        finally:
            # Reinstall the wrapper even if the call raised; otherwise later
            # calls would run the raw forward with the CPU-resident weights.
            module.forward = forward

    module.forward = forward

    return module
|
||||
|
||||
|
||||
def make_layers(
|
||||
num_hidden_layers: int, layer_fn: Callable[[], torch.nn.Module]
|
||||
) -> Tuple[int, int, torch.nn.ModuleList]:
|
||||
@@ -64,9 +130,10 @@ def make_layers(
|
||||
get_pp_group().rank_in_group,
|
||||
get_pp_group().world_size)
|
||||
modules = torch.nn.ModuleList(
|
||||
[PPMissingLayer() for _ in range(start_layer)] +
|
||||
[layer_fn() for _ in range(start_layer, end_layer)] +
|
||||
[PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
|
||||
[PPMissingLayer() for _ in range(start_layer)] + [
|
||||
maybe_offload_to_cpu(layer_fn())
|
||||
for _ in range(start_layer, end_layer)
|
||||
] + [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
|
||||
return start_layer, end_layer, modules
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user