[Performance] Add prefetch for checkpoints to OS page cache (#36012)
Signed-off-by: Artem Perevedentsev <aperevedents@nvidia.com>
This commit is contained in:
committed by
GitHub
parent
9b005edc48
commit
f5e59ee7a6
@@ -62,6 +62,9 @@ class LoadConfig:
|
||||
This is recommended for models on network filesystems (e.g., Lustre, NFS)
|
||||
as it avoids inefficient random reads, significantly speeding up model
|
||||
initialization. However, it uses more CPU RAM.
|
||||
- "prefetch": Checkpoint files are read into the OS page cache before
|
||||
workers load them, speeding up the model loading phase. Useful on
|
||||
network or high-latency storage.
|
||||
- "torchao": Weights are loaded in upfront and then reconstructed
|
||||
into torchao tensor subclasses. This is used when the checkpoint
|
||||
was quantized using torchao and saved using safetensors.
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Utilities for downloading and initializing model weights."""
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import fnmatch
|
||||
import glob
|
||||
@@ -9,6 +10,7 @@ import hashlib
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable, Generator
|
||||
@@ -720,6 +722,71 @@ def np_cache_weights_iterator(
|
||||
yield name, torch.from_numpy(param)
|
||||
|
||||
|
||||
def _prefetch_checkpoint(file_path: str) -> None:
|
||||
"""Prefetch a checkpoint file into the OS page cache.
|
||||
|
||||
Reads the file in 16MB blocks so the kernel caches its pages before
|
||||
workers load the same file.
|
||||
"""
|
||||
block_size = 16 * 1024 * 1024 # 16MB
|
||||
with open(file_path, "rb") as f:
|
||||
while f.read(block_size):
|
||||
pass
|
||||
|
||||
|
||||
def _prefetch_all_checkpoints(sorted_files: list[str]) -> None:
    """Start prefetching checkpoint files into page cache in a background thread.

    Files are sharded round-robin across distributed ranks so each rank warms
    a disjoint subset of the checkpoint. Up to 8 files are read concurrently;
    all work happens in a daemon thread so model loading is never blocked and
    the thread does not keep the process alive.

    Args:
        sorted_files: Checkpoint file paths, already in loading order.
    """
    if torch.distributed.is_initialized():
        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()
    else:
        rank = 0
        world_size = 1
    num_prefetch_threads = 8
    # Round-robin shard: this rank takes every world_size-th file.
    paths_to_prefetch = sorted_files[rank::world_size]
    total_for_rank = len(paths_to_prefetch)

    async def _prefetch_all() -> None:
        # Cap concurrent reads so we don't thrash the storage backend.
        semaphore = asyncio.Semaphore(num_prefetch_threads)
        completed = 0
        next_log_pct = 10

        async def prefetch_one(path: str) -> None:
            nonlocal completed, next_log_pct
            try:
                async with semaphore:
                    # Blocking file read runs off the event loop.
                    await asyncio.to_thread(_prefetch_checkpoint, path)
                completed += 1
                if total_for_rank > 0 and next_log_pct <= 100:
                    pct = 100 * completed / total_for_rank
                    if pct >= next_log_pct:
                        # Log the decile actually reached and jump past it.
                        # (Logging next_log_pct and advancing by a fixed 10
                        # would under-report progress whenever a single
                        # completion crosses several deciles, e.g. with only
                        # 3 files per rank.)
                        reached_pct = int(pct // 10) * 10
                        logger.info(
                            "Prefetching checkpoint files: %d%% (%d/%d)",
                            reached_pct,
                            completed,
                            total_for_rank,
                        )
                        next_log_pct = reached_pct + 10
            except Exception:
                # Best-effort: a failed prefetch only loses the cache warmup,
                # the actual weight load will still read the file itself.
                logger.warning(
                    "Failed to prefetch checkpoint file %r.", path, exc_info=True
                )

        await asyncio.gather(*(prefetch_one(p) for p in paths_to_prefetch))

    def _run_prefetch() -> None:
        start = time.perf_counter()
        asyncio.run(_prefetch_all())
        elapsed = time.perf_counter() - start
        logger.info(
            "Prefetching checkpoint files into page cache finished in %.2fs",
            elapsed,
        )

    logger.info("Prefetching checkpoint files into page cache started (in background)")
    threading.Thread(target=_run_prefetch, daemon=True).start()
|
||||
|
||||
|
||||
def safetensors_weights_iterator(
|
||||
hf_weights_files: list[str],
|
||||
use_tqdm_on_load: bool,
|
||||
@@ -736,9 +803,14 @@ def safetensors_weights_iterator(
|
||||
if safetensors_load_strategy == "eager":
|
||||
loading_desc += " (eager)"
|
||||
|
||||
sorted_files = sorted(hf_weights_files, key=_natural_sort_key)
|
||||
|
||||
if safetensors_load_strategy == "prefetch":
|
||||
_prefetch_all_checkpoints(sorted_files)
|
||||
|
||||
leftover_state_dict: dict[str, torch.Tensor] = {}
|
||||
for st_file in tqdm(
|
||||
sorted(hf_weights_files, key=_natural_sort_key),
|
||||
sorted_files,
|
||||
desc=loading_desc,
|
||||
disable=not enable_tqdm(use_tqdm_on_load),
|
||||
bar_format=_BAR_FORMAT,
|
||||
|
||||
Reference in New Issue
Block a user