def _prefetch_checkpoint(file_path: str, block_size: int = 16 * 1024 * 1024) -> None:
    """Prefetch a checkpoint file into the OS page cache.

    Sequentially reads the whole file so the kernel caches its pages
    before workers load the same file. The data itself is discarded —
    only the page-cache side effect matters.

    Args:
        file_path: Path of the checkpoint file to warm.
        block_size: Read granularity in bytes. Defaults to 16 MiB
            (the original hard-coded value), now exposed as a
            backward-compatible keyword parameter.
    """
    # readinto() a single preallocated buffer instead of f.read():
    # avoids allocating a fresh `block_size`-byte `bytes` object per
    # chunk while producing the identical sequence of kernel reads.
    buf = bytearray(block_size)
    view = memoryview(buf)
    with open(file_path, "rb") as f:
        # readinto returns 0 at EOF, terminating the loop.
        while f.readinto(view):
            pass
def _prefetch_all_checkpoints(sorted_files: list[str]) -> None:
    """Start prefetching checkpoint files into page cache in a background thread.

    Returns immediately; the actual reads happen on a daemon thread so
    model loading is never blocked by the prefetch itself.
    """
    # Stripe the file list across ranks so each rank warms a disjoint
    # subset; with a shared page cache the ranks cooperatively cover all
    # files.
    # NOTE(review): striding by *global* rank assumes all ranks share one
    # page cache (single node or shared storage cache) — on multi-node
    # setups each node would warm only a subset of files; confirm intent.
    if torch.distributed.is_initialized():
        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()
    else:
        # No distributed init: this process prefetches everything.
        rank = 0
        world_size = 1
    num_prefetch_threads = 8  # max concurrent blocking file reads
    paths_to_prefetch = sorted_files[rank::world_size]
    total_for_rank = len(paths_to_prefetch)

    async def _prefetch_all() -> None:
        # Semaphore bounds how many asyncio.to_thread reads run at once.
        semaphore = asyncio.Semaphore(num_prefetch_threads)
        completed = 0
        next_log_pct = 10  # next progress threshold (percent) to log

        async def prefetch_one(path: str) -> None:
            nonlocal completed, next_log_pct
            try:
                async with semaphore:
                    # Run the blocking sequential read off the event loop.
                    await asyncio.to_thread(_prefetch_checkpoint, path)
                    # Counter/log updates execute on the event-loop
                    # thread between awaits, so no extra locking is
                    # needed despite the worker threads.
                    completed += 1
                    if total_for_rank > 0 and next_log_pct <= 100:
                        pct = 100 * completed / total_for_rank
                        if pct >= next_log_pct:
                            logger.info(
                                "Prefetching checkpoint files: %d%% (%d/%d)",
                                next_log_pct,
                                completed,
                                total_for_rank,
                            )
                            # Advances one threshold per completion; if a
                            # single file crosses several 10% marks, later
                            # completions catch the log up.
                            next_log_pct += 10
            except Exception:
                # Best-effort: a failed prefetch must never abort model
                # loading, so log with traceback and continue.
                logger.warning(
                    "Failed to prefetch checkpoint file %r.", path, exc_info=True
                )

        await asyncio.gather(*(prefetch_one(p) for p in paths_to_prefetch))

    def _run_prefetch() -> None:
        # Runs on the daemon thread with its own event loop; timed so the
        # total prefetch duration shows up in the logs.
        start = time.perf_counter()
        asyncio.run(_prefetch_all())
        elapsed = time.perf_counter() - start
        logger.info(
            "Prefetching checkpoint files into page cache finished in %.2fs",
            elapsed,
        )

    logger.info("Prefetching checkpoint files into page cache started (in background)")
    # Daemon thread: never keeps the interpreter alive if loading
    # finishes (or crashes) before the prefetch does.
    threading.Thread(target=_run_prefetch, daemon=True).start()
sorted_files = sorted(hf_weights_files, key=_natural_sort_key) + + if safetensors_load_strategy == "prefetch": + _prefetch_all_checkpoints(sorted_files) + leftover_state_dict: dict[str, torch.Tensor] = {} for st_file in tqdm( - sorted(hf_weights_files, key=_natural_sort_key), + sorted_files, desc=loading_desc, disable=not enable_tqdm(use_tqdm_on_load), bar_format=_BAR_FORMAT,