[Performance] Add prefetch for checkpoints to OS page cache (#36012)

Signed-off-by: Artem Perevedentsev <aperevedents@nvidia.com>
This commit is contained in:
Artem Perevedentsev
2026-03-16 13:32:02 +02:00
committed by GitHub
parent 9b005edc48
commit f5e59ee7a6
2 changed files with 76 additions and 1 deletion

View File

@@ -62,6 +62,9 @@ class LoadConfig:
This is recommended for models on network filesystems (e.g., Lustre, NFS)
as it avoids inefficient random reads, significantly speeding up model
initialization. However, it uses more CPU RAM.
- "prefetch": Checkpoint files are read into the OS page cache before
workers load them, speeding up the model loading phase. Useful on
network or high-latency storage.
- "torchao": Weights are loaded in upfront and then reconstructed
into torchao tensor subclasses. This is used when the checkpoint
was quantized using torchao and saved using safetensors.

View File

@@ -2,6 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utilities for downloading and initializing model weights."""
import asyncio
import concurrent.futures
import fnmatch
import glob
@@ -9,6 +10,7 @@ import hashlib
import json
import os
import tempfile
import threading
import time
from collections import defaultdict
from collections.abc import Callable, Generator
@@ -720,6 +722,71 @@ def np_cache_weights_iterator(
yield name, torch.from_numpy(param)
def _prefetch_checkpoint(file_path: str) -> None:
"""Prefetch a checkpoint file into the OS page cache.
Reads the file in 16MB blocks so the kernel caches its pages before
workers load the same file.
"""
block_size = 16 * 1024 * 1024 # 16MB
with open(file_path, "rb") as f:
while f.read(block_size):
pass
def _prefetch_all_checkpoints(sorted_files: list[str]) -> None:
    """Kick off background prefetching of checkpoint files into the page cache.

    When torch.distributed is initialized, the sorted file list is striped
    across ranks (each rank warms every world_size-th file starting at its
    own rank) so the full checkpoint is covered exactly once cluster-wide.
    The prefetching itself runs in a daemon thread and therefore does not
    block model loading.
    """
    if torch.distributed.is_initialized():
        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()
    else:
        rank, world_size = 0, 1

    max_concurrent_reads = 8
    my_files = sorted_files[rank::world_size]
    num_files = len(my_files)

    async def _prefetch_all() -> None:
        # Bound the number of files being read at once to avoid saturating
        # the storage backend with too many parallel streams.
        gate = asyncio.Semaphore(max_concurrent_reads)
        done_count = 0
        log_threshold = 10  # next progress percentage to report

        async def prefetch_one(path: str) -> None:
            nonlocal done_count, log_threshold
            try:
                async with gate:
                    await asyncio.to_thread(_prefetch_checkpoint, path)
                done_count += 1
                # Report progress in 10% steps, at most one line per completion.
                if num_files > 0 and log_threshold <= 100:
                    progress = 100 * done_count / num_files
                    if progress >= log_threshold:
                        logger.info(
                            "Prefetching checkpoint files: %d%% (%d/%d)",
                            log_threshold,
                            done_count,
                            num_files,
                        )
                        log_threshold += 10
            except Exception:
                # Best effort: a failed prefetch only loses the cache warm-up,
                # the actual weight load will still read the file itself.
                logger.warning(
                    "Failed to prefetch checkpoint file %r.", path, exc_info=True
                )

        await asyncio.gather(*(prefetch_one(p) for p in my_files))

    def _run_prefetch() -> None:
        started_at = time.perf_counter()
        asyncio.run(_prefetch_all())
        logger.info(
            "Prefetching checkpoint files into page cache finished in %.2fs",
            time.perf_counter() - started_at,
        )

    logger.info("Prefetching checkpoint files into page cache started (in background)")
    threading.Thread(target=_run_prefetch, daemon=True).start()
def safetensors_weights_iterator(
hf_weights_files: list[str],
use_tqdm_on_load: bool,
@@ -736,9 +803,14 @@ def safetensors_weights_iterator(
if safetensors_load_strategy == "eager":
loading_desc += " (eager)"
sorted_files = sorted(hf_weights_files, key=_natural_sort_key)
if safetensors_load_strategy == "prefetch":
_prefetch_all_checkpoints(sorted_files)
leftover_state_dict: dict[str, torch.Tensor] = {}
for st_file in tqdm(
sorted(hf_weights_files, key=_natural_sort_key),
sorted_files,
desc=loading_desc,
disable=not enable_tqdm(use_tqdm_on_load),
bar_format=_BAR_FORMAT,