[Performance] Auto-enable prefetch on NFS with RAM guard (#37673)
Signed-off-by: Artem Perevedentsev <aperevedents@nvidia.com>
This commit is contained in:
committed by
GitHub
parent
679c6a3ecc
commit
a93a53f8a1
@@ -729,6 +729,61 @@ def np_cache_weights_iterator(
|
||||
yield name, torch.from_numpy(param)
|
||||
|
||||
|
||||
def _checkpoints_fit_in_ram(files: list[str], threshold: float = 0.9) -> bool:
    """Tell whether the checkpoint files fit in currently available RAM.

    Args:
        files: Paths of the checkpoint files whose on-disk sizes are summed.
        threshold: Fraction of the available RAM the files may occupy.

    Returns:
        True when *files* is empty or the summed file size is at most
        ``threshold`` times the available RAM reported by psutil; False
        otherwise (a warning is logged in that case).
    """
    if not files:
        return True

    import psutil

    bytes_per_gib = 1024**3

    # Sum the on-disk size of every shard before querying memory.
    checkpoint_bytes = 0
    for path in files:
        checkpoint_bytes += os.path.getsize(path)

    free_bytes = psutil.virtual_memory().available
    if checkpoint_bytes <= threshold * free_bytes:
        return True

    logger.warning(
        "NFS detected but checkpoint total size (%.2f GiB) exceeds "
        "%.0f%% of available RAM (%.2f GiB). Skipping prefetching checkpoints.",
        checkpoint_bytes / bytes_per_gib,
        threshold * 100,
        free_bytes / bytes_per_gib,
    )
    return False
|
||||
|
||||
|
||||
def _is_nfs_path(files: list[str], mounts_file: str = "/proc/mounts") -> bool:
    """Check whether the first file in *files* resides on an NFS
    filesystem (Linux only).

    Args:
        files: Checkpoint file paths. Only the first entry is inspected.
        mounts_file: Mount table to read, in ``/proc/mounts`` format
            (whitespace-separated ``device mount_point fstype ...`` lines).

    Returns:
        True if the mount entry selected for the first file has filesystem
        type ``nfs`` or ``nfs4``. False otherwise, including when *files* is
        empty or the mount table cannot be read or parsed.
    """
    if not files:
        return False

    import re

    try:
        # Only the first file is checked — all checkpoint shards are assumed
        # to reside in the same directory and therefore on the same
        # filesystem.
        resolved = os.path.realpath(files[0])
        best_mount = ""
        best_fstype = ""
        # The mount table may contain nested mount points (e.g. "/" -> ext4,
        # "/data" -> nfs4, "/data/local" -> ext4). We pick the entry with
        # the longest matching mount_point.
        with open(mounts_file) as f:
            for line in f:
                parts = line.split()
                if len(parts) < 3:
                    continue
                mount_point, fstype = parts[1], parts[2]
                if "\\" in mount_point:
                    # The kernel writes space, tab, newline and backslash in
                    # mount points as three-digit octal escapes ("\040" for
                    # a space); decode them so the comparison sees the real
                    # path.
                    mount_point = re.sub(
                        r"\\([0-7]{3})",
                        lambda m: chr(int(m.group(1), 8)),
                        mount_point,
                    )
                matches = resolved == mount_point or resolved.startswith(
                    os.path.join(mount_point, "")
                )
                # ">=" rather than ">": when the same mount point appears
                # more than once (e.g. "rootfs /" followed by the real root,
                # or an over-mounted directory), the later entry is the one
                # that actually serves the path.
                if matches and len(mount_point) >= len(best_mount):
                    best_mount = mount_point
                    best_fstype = fstype
        return best_fstype in ("nfs", "nfs4")
    except Exception:
        # /proc/mounts is Linux-specific; on other OSes (or if the read
        # fails for any reason) we fall back to "not NFS" rather than
        # crashing model loading.
        return False
|
||||
|
||||
|
||||
def _prefetch_checkpoint(file_path: str) -> None:
|
||||
"""Prefetch a checkpoint file into the OS page cache.
|
||||
|
||||
@@ -797,7 +852,7 @@ def _prefetch_all_checkpoints(sorted_files: list[str]) -> None:
|
||||
def safetensors_weights_iterator(
|
||||
hf_weights_files: list[str],
|
||||
use_tqdm_on_load: bool,
|
||||
safetensors_load_strategy: str = "lazy",
|
||||
safetensors_load_strategy: str | None = None,
|
||||
local_expert_ids: set[int] | None = None,
|
||||
) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
"""Iterate over the weights in the model safetensor files.
|
||||
@@ -812,7 +867,12 @@ def safetensors_weights_iterator(
|
||||
|
||||
sorted_files = sorted(hf_weights_files, key=_natural_sort_key)
|
||||
|
||||
if safetensors_load_strategy == "prefetch":
|
||||
should_prefetch = safetensors_load_strategy == "prefetch" or (
|
||||
safetensors_load_strategy is None
|
||||
and _is_nfs_path(sorted_files)
|
||||
and _checkpoints_fit_in_ram(sorted_files)
|
||||
)
|
||||
if should_prefetch:
|
||||
_prefetch_all_checkpoints(sorted_files)
|
||||
|
||||
leftover_state_dict: dict[str, torch.Tensor] = {}
|
||||
|
||||
Reference in New Issue
Block a user