[BugFix] Fix fastsafetensors TP all procs using all GPUs (#34070)

Signed-off-by: Nick Hill <nickhill123@gmail.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
Nick Hill
2026-02-08 23:15:46 -08:00
committed by GitHub
parent 22b64948f6
commit d9bede0314

View File

@@ -801,8 +801,8 @@ def runai_safetensors_weights_iterator(
yield from tensor_iter
def _init_loader(
pg: torch.distributed.ProcessGroup,
def _init_fastsafetensors_loader(
pg: "torch.distributed.ProcessGroup",
device: torch.device,
f_list: list[str],
*,
@@ -825,13 +825,16 @@ def fastsafetensors_weights_iterator(
else:
pg = SingleGroup()
device = torch.device(f"cuda:{pg.rank()}")
device = torch.device(f"cuda:{current_platform.current_device()}")
weight_files_sub_lists = [
hf_weights_files[i : i + pg.size()]
for i in range(0, len(hf_weights_files), pg.size())
]
nogds = False
# Use nogds=True for TP > 1 to avoid cuFileDriverOpen() which
# initializes the GDS DMA subsystem for all visible GPUs, creating
# unwanted CUDA contexts on every device.
nogds = pg.size() > 1
for f_list in tqdm(
weight_files_sub_lists,
@@ -839,7 +842,7 @@ def fastsafetensors_weights_iterator(
disable=not enable_tqdm(use_tqdm_on_load),
bar_format=_BAR_FORMAT,
):
loader = _init_loader(pg, device, f_list, nogds=nogds)
loader = _init_fastsafetensors_loader(pg, device, f_list, nogds=nogds)
try:
try:
fb = loader.copy_files_to_device()
@@ -853,7 +856,7 @@ def fastsafetensors_weights_iterator(
"GDS not enabled, setting `nogds=True`.\n"
"For more information, see: https://github.com/foundation-model-stack/fastsafetensors?tab=readme-ov-file#basic-api-usages"
)
loader = _init_loader(pg, device, f_list, nogds=nogds)
loader = _init_fastsafetensors_loader(pg, device, f_list, nogds=nogds)
fb = loader.copy_files_to_device()
try: