[BugFix] Fix torch distributed stateless PG backend init (#14870)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -299,13 +299,10 @@ def stateless_init_torch_distributed_process_group(
|
||||
# different systems (e.g. RPC) in case the store is multi-tenant.
|
||||
prefix_store = PrefixStore(init_method, store)
|
||||
|
||||
pg_options = ProcessGroup.Options(backend=backend, timeout=timeout)
|
||||
|
||||
pg: ProcessGroup = ProcessGroup(
|
||||
prefix_store,
|
||||
group_rank,
|
||||
group_size,
|
||||
pg_options,
|
||||
)
|
||||
|
||||
if backend == "gloo":
|
||||
@@ -327,7 +324,10 @@ def stateless_init_torch_distributed_process_group(
|
||||
backend_options)
|
||||
backend_type = ProcessGroup.BackendType.NCCL
|
||||
device = torch.device("cuda")
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported torch distributed backend: {backend}")
|
||||
|
||||
pg._set_default_backend(backend_type)
|
||||
backend_class._set_sequence_number_for_group()
|
||||
|
||||
pg._register_backend(device, backend_type, backend_class)
|
||||
|
||||
Reference in New Issue
Block a user