[Performance] Auto-enable prefetch on NFS with RAM guard (#37673)

Signed-off-by: Artem Perevedentsev <aperevedents@nvidia.com>
This commit is contained in:
Artem Perevedentsev
2026-03-25 02:31:14 +02:00
committed by GitHub
parent 679c6a3ecc
commit a93a53f8a1
3 changed files with 69 additions and 5 deletions

View File

@@ -54,10 +54,14 @@ class LoadConfig:
download_dir: str | None = None
"""Directory to download and load the weights, default to the default
cache directory of Hugging Face."""
safetensors_load_strategy: str = "lazy"
safetensors_load_strategy: str | None = None
"""Specifies the loading strategy for safetensors weights.
- "lazy" (default): Weights are memory-mapped from the file. This enables
- None (default): Uses memory-mapped (lazy) loading. When an NFS
filesystem is detected and the total checkpoint size fits within 90%%
of available RAM, prefetching is enabled automatically.
- "lazy": Weights are memory-mapped from the file. This enables
on-demand loading and is highly efficient for models on local storage.
Unlike the default (None), auto-prefetch on NFS is not performed.
- "eager": The entire file is read into CPU memory upfront before loading.
This is recommended for models on network filesystems (e.g., Lustre, NFS)
as it avoids inefficient random reads, significantly speeding up model

View File

@@ -388,7 +388,7 @@ class EngineArgs:
allowed_local_media_path: str = ModelConfig.allowed_local_media_path
allowed_media_domains: list[str] | None = ModelConfig.allowed_media_domains
download_dir: str | None = LoadConfig.download_dir
safetensors_load_strategy: str = LoadConfig.safetensors_load_strategy
safetensors_load_strategy: str | None = LoadConfig.safetensors_load_strategy
load_format: str | LoadFormats = LoadConfig.load_format
config_format: str = ModelConfig.config_format
dtype: ModelDType = ModelConfig.dtype

View File

@@ -729,6 +729,61 @@ def np_cache_weights_iterator(
yield name, torch.from_numpy(param)
def _checkpoints_fit_in_ram(files: list[str], threshold: float = 0.9) -> bool:
    """Decide whether the checkpoint *files* are small enough to prefetch.

    The on-disk sizes of all *files* are summed and compared against
    ``threshold`` times the amount of RAM that psutil currently reports
    as available.

    Args:
        files: Paths of the checkpoint files to measure.
        threshold: Fraction of the available RAM the checkpoints may occupy.

    Returns:
        True when *files* is empty or the combined size is within the
        budget. False otherwise, in which case a warning with the measured
        sizes is logged.
    """
    if not files:
        # Nothing to load, so nothing can exceed the budget.
        return True

    # Imported lazily: the empty-input path above never needs psutil.
    import psutil

    checkpoint_bytes = 0
    for path in files:
        checkpoint_bytes += os.path.getsize(path)
    free_bytes = psutil.virtual_memory().available

    if checkpoint_bytes <= threshold * free_bytes:
        return True

    bytes_per_gib = 1024**3
    logger.warning(
        "NFS detected but checkpoint total size (%.2f GiB) exceeds "
        "%.0f%% of available RAM (%.2f GiB). Skipping prefetching checkpoints.",
        checkpoint_bytes / bytes_per_gib,
        threshold * 100,
        free_bytes / bytes_per_gib,
    )
    return False
def _is_nfs_path(files: list[str], mounts_file: str = "/proc/mounts") -> bool:
    """Check whether the first file in *files* resides on an NFS
    filesystem (Linux only).

    Args:
        files: Checkpoint file paths. Only the first entry is inspected.
        mounts_file: Mount table to parse, in ``/proc/mounts`` format
            (whitespace-separated ``device mount_point fstype ...`` lines).

    Returns:
        True when the mount entry selected for the first file has
        filesystem type ``nfs`` or ``nfs4``. False otherwise, including
        when *files* is empty or the mount table cannot be read or parsed.
    """
    if not files:
        return False
    try:
        import re

        # Only the first file is checked — all checkpoint shards reside
        # in the same directory and therefore on the same filesystem.
        resolved = os.path.realpath(files[0])
        best_mount = ""
        best_fstype = ""
        # The mount table may contain nested mount points (e.g. "/" -> ext4,
        # "/data" -> nfs4, "/data/local" -> ext4). We pick the entry with
        # the longest matching mount_point — the same "longest prefix match"
        # rule the kernel uses to decide which filesystem serves a path.
        with open(mounts_file) as f:
            for line in f:
                parts = line.split()
                if len(parts) < 3:
                    continue
                mount_point, fstype = parts[1], parts[2]
                if "\\" in mount_point:
                    # Space, tab, newline and backslash are stored as
                    # three-digit octal escapes (e.g. "\040" for a space);
                    # decode them so the field compares equal to a real path.
                    mount_point = re.sub(
                        r"\\([0-7]{3})",
                        lambda m: chr(int(m.group(1), 8)),
                        mount_point,
                    )
                if resolved != mount_point and not resolved.startswith(
                    os.path.join(mount_point, "")
                ):
                    continue
                # ">=" rather than ">": when several entries share one mount
                # point (stacked mounts, e.g. an autofs trigger with the real
                # NFS mount on top), the later line is the one mounted last
                # and therefore the one that actually serves the path.
                if len(mount_point) >= len(best_mount):
                    best_mount = mount_point
                    best_fstype = fstype
        return best_fstype in ("nfs", "nfs4")
    except Exception:
        # /proc/mounts is Linux-specific; on other OSes (or if the read
        # fails for any reason) we fall back to "not NFS" rather than
        # crashing model loading.
        return False
def _prefetch_checkpoint(file_path: str) -> None:
"""Prefetch a checkpoint file into the OS page cache.
@@ -797,7 +852,7 @@ def _prefetch_all_checkpoints(sorted_files: list[str]) -> None:
def safetensors_weights_iterator(
hf_weights_files: list[str],
use_tqdm_on_load: bool,
safetensors_load_strategy: str = "lazy",
safetensors_load_strategy: str | None = None,
local_expert_ids: set[int] | None = None,
) -> Generator[tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model safetensor files.
@@ -812,7 +867,12 @@ def safetensors_weights_iterator(
sorted_files = sorted(hf_weights_files, key=_natural_sort_key)
if safetensors_load_strategy == "prefetch":
should_prefetch = safetensors_load_strategy == "prefetch" or (
safetensors_load_strategy is None
and _is_nfs_path(sorted_files)
and _checkpoints_fit_in_ram(sorted_files)
)
if should_prefetch:
_prefetch_all_checkpoints(sorted_files)
leftover_state_dict: dict[str, torch.Tensor] = {}