[BugFix] Fix fastsafetensors with TP: every process was using all GPUs (#34070)
Signed-off-by: Nick Hill <nickhill123@gmail.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -801,8 +801,8 @@ def runai_safetensors_weights_iterator(
|
||||
yield from tensor_iter
|
||||
|
||||
|
||||
def _init_loader(
|
||||
pg: torch.distributed.ProcessGroup,
|
||||
def _init_fastsafetensors_loader(
|
||||
pg: "torch.distributed.ProcessGroup",
|
||||
device: torch.device,
|
||||
f_list: list[str],
|
||||
*,
|
||||
@@ -825,13 +825,16 @@ def fastsafetensors_weights_iterator(
|
||||
else:
|
||||
pg = SingleGroup()
|
||||
|
||||
device = torch.device(f"cuda:{pg.rank()}")
|
||||
device = torch.device(f"cuda:{current_platform.current_device()}")
|
||||
weight_files_sub_lists = [
|
||||
hf_weights_files[i : i + pg.size()]
|
||||
for i in range(0, len(hf_weights_files), pg.size())
|
||||
]
|
||||
|
||||
nogds = False
|
||||
# Use nogds=True for TP > 1 to avoid cuFileDriverOpen() which
|
||||
# initializes the GDS DMA subsystem for all visible GPUs, creating
|
||||
# unwanted CUDA contexts on every device.
|
||||
nogds = pg.size() > 1
|
||||
|
||||
for f_list in tqdm(
|
||||
weight_files_sub_lists,
|
||||
@@ -839,7 +842,7 @@ def fastsafetensors_weights_iterator(
|
||||
disable=not enable_tqdm(use_tqdm_on_load),
|
||||
bar_format=_BAR_FORMAT,
|
||||
):
|
||||
loader = _init_loader(pg, device, f_list, nogds=nogds)
|
||||
loader = _init_fastsafetensors_loader(pg, device, f_list, nogds=nogds)
|
||||
try:
|
||||
try:
|
||||
fb = loader.copy_files_to_device()
|
||||
@@ -853,7 +856,7 @@ def fastsafetensors_weights_iterator(
|
||||
"GDS not enabled, setting `nogds=True`.\n"
|
||||
"For more information, see: https://github.com/foundation-model-stack/fastsafetensors?tab=readme-ov-file#basic-api-usages"
|
||||
)
|
||||
loader = _init_loader(pg, device, f_list, nogds=nogds)
|
||||
loader = _init_fastsafetensors_loader(pg, device, f_list, nogds=nogds)
|
||||
fb = loader.copy_files_to_device()
|
||||
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user