[offloader] v2: Hide weight onloading latency via prefetching (#29941)
Signed-off-by: Ming Yang <minos.future@gmail.com> Signed-off-by: Michael Goin <mgoin64@gmail.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
This commit is contained in:
@@ -9,11 +9,9 @@ from typing import Any, Literal, Protocol, overload
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.func import functional_call
|
||||
from torch.nn.modules.module import register_module_module_registration_hook
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
@@ -31,14 +29,11 @@ from vllm.model_executor.models.interfaces import supports_any_eagle
|
||||
from vllm.multimodal import NestedTensors
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.mem_utils import format_gib
|
||||
from vllm.utils.platform_utils import (
|
||||
is_pin_memory_available,
|
||||
is_uva_available,
|
||||
)
|
||||
from vllm.utils.torch_utils import (
|
||||
direct_register_custom_op,
|
||||
get_accelerator_view_from_cpu_tensor,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -612,98 +607,6 @@ class PPMissingLayer(torch.nn.Identity):
|
||||
return args[0] if args else next(iter(kwargs.values()))
|
||||
|
||||
|
||||
_CPU_OFFLOAD_BYTES = 0
|
||||
_CPU_OFFLOAD_MAX_BYTES = 0
|
||||
_CPU_OFFLOAD_PARAMS = set()
|
||||
|
||||
|
||||
def set_cpu_offload_max_bytes(max_bytes: int) -> None:
|
||||
global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
|
||||
_CPU_OFFLOAD_BYTES = 0
|
||||
_CPU_OFFLOAD_MAX_BYTES = max_bytes
|
||||
|
||||
|
||||
def set_cpu_offload_params(params: set[str]) -> None:
|
||||
global _CPU_OFFLOAD_PARAMS
|
||||
_CPU_OFFLOAD_PARAMS = params
|
||||
|
||||
|
||||
def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
|
||||
if (params := next(module.parameters(), None)) is None:
|
||||
return module
|
||||
|
||||
device = params.device
|
||||
|
||||
if device == torch.device("cpu"):
|
||||
return module
|
||||
|
||||
global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
|
||||
if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
|
||||
return module
|
||||
|
||||
pin_memory = (
|
||||
is_pin_memory_available() and not envs.VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY
|
||||
)
|
||||
uva_offloading = is_uva_available() and not envs.VLLM_WEIGHT_OFFLOADING_DISABLE_UVA
|
||||
|
||||
# offload parameters to CPU
|
||||
# use pin_memory if possible, which helps cudagraph capture speed
|
||||
offloaded_parameters = False
|
||||
for name, p in module.named_parameters():
|
||||
if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
|
||||
# we use per-parameter offloading
|
||||
# one module might have some parameters offloaded and some not
|
||||
break
|
||||
|
||||
if _CPU_OFFLOAD_PARAMS:
|
||||
# Check if parameter belongs to the offloading set
|
||||
# Add dots here to ensure we match full segments only
|
||||
# e.g., "experts.w2_weight" matches "mlp.experts.w2_weight" but not
|
||||
# "mlp.experts.w2_weight_scale"
|
||||
should_offload = any(
|
||||
f".{param}." in f".{name}." for param in _CPU_OFFLOAD_PARAMS
|
||||
)
|
||||
if not should_offload:
|
||||
continue
|
||||
|
||||
cpu_data = p.data.to(device="cpu")
|
||||
if pin_memory:
|
||||
cpu_data = cpu_data.pin_memory()
|
||||
|
||||
if not uva_offloading:
|
||||
p.data = cpu_data
|
||||
else:
|
||||
p.data = get_accelerator_view_from_cpu_tensor(cpu_data)
|
||||
p._vllm_is_uva_offloaded = True
|
||||
|
||||
_CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
|
||||
offloaded_parameters = True
|
||||
|
||||
if offloaded_parameters and not uva_offloading:
|
||||
original_forward = module.forward
|
||||
|
||||
def forward(*args, **kwargs):
|
||||
module.forward = original_forward
|
||||
device_state = {
|
||||
# here we blindly call `to(device)`
|
||||
# if the parameter is already on the device, it will be a no-op
|
||||
k: v.to(device, non_blocking=True)
|
||||
for k, v in module.state_dict().items()
|
||||
}
|
||||
|
||||
# set `tie_weights=False` as tied weights in original model
|
||||
# become untied when calling .to(device) individually
|
||||
output = functional_call(
|
||||
module, device_state, args=args, kwargs=kwargs, tie_weights=False
|
||||
)
|
||||
module.forward = forward
|
||||
return output
|
||||
|
||||
module.forward = forward
|
||||
|
||||
return module
|
||||
|
||||
|
||||
def make_layers(
|
||||
num_hidden_layers: int,
|
||||
layer_fn: LayerFn,
|
||||
@@ -711,25 +614,31 @@ def make_layers(
|
||||
) -> tuple[int, int, torch.nn.ModuleList]:
|
||||
"""Make a list of layers with the given layer function, taking
|
||||
pipeline parallelism into account.
|
||||
|
||||
Args:
|
||||
num_hidden_layers: Total number of hidden layers in the model.
|
||||
layer_fn: Function to create a layer given its index.
|
||||
prefix: Prefix for layer names.
|
||||
|
||||
Returns:
|
||||
Tuple of (start_layer, end_layer, modules).
|
||||
"""
|
||||
from vllm.distributed.parallel_state import get_pp_group
|
||||
from vllm.distributed.utils import get_pp_indices
|
||||
from vllm.model_executor.offloader import get_offloader
|
||||
|
||||
start_layer, end_layer = get_pp_indices(
|
||||
num_hidden_layers, get_pp_group().rank_in_group, get_pp_group().world_size
|
||||
)
|
||||
|
||||
modules = torch.nn.ModuleList(
|
||||
[PPMissingLayer() for _ in range(start_layer)]
|
||||
+ [
|
||||
maybe_offload_to_cpu(layer_fn(prefix=f"{prefix}.{idx}"))
|
||||
for idx in range(start_layer, end_layer)
|
||||
]
|
||||
+ get_offloader().wrap_modules(
|
||||
layer_fn(prefix=f"{prefix}.{idx}") for idx in range(start_layer, end_layer)
|
||||
)
|
||||
+ [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)]
|
||||
)
|
||||
if _CPU_OFFLOAD_MAX_BYTES > 0:
|
||||
logger.info(
|
||||
"Total CPU offloaded parameters: %s GBs", format_gib(_CPU_OFFLOAD_BYTES)
|
||||
)
|
||||
|
||||
return start_layer, end_layer, modules
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user