[Performance] Auto-enable prefetch on NFS with RAM guard (#37673)

Signed-off-by: Artem Perevedentsev <aperevedents@nvidia.com>
This commit is contained in:
Artem Perevedentsev
2026-03-25 02:31:14 +02:00
committed by GitHub
parent 679c6a3ecc
commit a93a53f8a1
3 changed files with 69 additions and 5 deletions

View File

@@ -729,6 +729,61 @@ def np_cache_weights_iterator(
yield name, torch.from_numpy(param)
def _checkpoints_fit_in_ram(files: list[str], threshold: float = 0.9) -> bool:
"""Return True if total size of *files* fits within *threshold* of available RAM."""
if not files:
return True
import psutil
total_size = sum(os.path.getsize(f) for f in files)
available_ram = psutil.virtual_memory().available
fits = total_size <= threshold * available_ram
if not fits:
logger.warning(
"NFS detected but checkpoint total size (%.2f GiB) exceeds "
"%.0f%% of available RAM (%.2f GiB). Skipping prefetching checkpoints.",
total_size / (1024**3),
threshold * 100,
available_ram / (1024**3),
)
return fits
def _is_nfs_path(files: list[str]) -> bool:
"""Check whether the first file in *files* resides on an NFS
filesystem (Linux only)."""
if not files:
return False
try:
# Only the first file is checked — all checkpoint shards reside
# in the same directory and therefore on the same filesystem.
resolved = os.path.realpath(files[0])
best_mount = ""
best_fstype = ""
# /proc/mounts may contain nested mount points (e.g. "/" -> ext4,
# "/data" -> nfs4, "/data/local" -> ext4). We pick the entry with
# the longest matching mount_point — the same "longest prefix match"
# rule the kernel uses to decide which filesystem serves a path.
with open("/proc/mounts") as f:
for line in f:
parts = line.split()
if len(parts) < 3:
continue
mount_point, fstype = parts[1], parts[2]
if (
resolved == mount_point
or resolved.startswith(os.path.join(mount_point, ""))
) and len(mount_point) > len(best_mount):
best_mount = mount_point
best_fstype = fstype
return best_fstype in ("nfs", "nfs4")
except Exception:
# /proc/mounts is Linux-specific; on other OSes (or if the read
# fails for any reason) we fall back to "not NFS" rather than
# crashing model loading.
return False
def _prefetch_checkpoint(file_path: str) -> None:
"""Prefetch a checkpoint file into the OS page cache.
@@ -797,7 +852,7 @@ def _prefetch_all_checkpoints(sorted_files: list[str]) -> None:
def safetensors_weights_iterator(
hf_weights_files: list[str],
use_tqdm_on_load: bool,
safetensors_load_strategy: str = "lazy",
safetensors_load_strategy: str | None = None,
local_expert_ids: set[int] | None = None,
) -> Generator[tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model safetensor files.
@@ -812,7 +867,12 @@ def safetensors_weights_iterator(
sorted_files = sorted(hf_weights_files, key=_natural_sort_key)
if safetensors_load_strategy == "prefetch":
should_prefetch = safetensors_load_strategy == "prefetch" or (
safetensors_load_strategy is None
and _is_nfs_path(sorted_files)
and _checkpoints_fit_in_ram(sorted_files)
)
if should_prefetch:
_prefetch_all_checkpoints(sorted_files)
leftover_state_dict: dict[str, torch.Tensor] = {}