Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -26,13 +26,13 @@ def distributed_run(fn, world_size):
|
||||
processes = []
|
||||
for i in range(number_of_processes):
|
||||
env = {}
|
||||
env['RANK'] = str(i)
|
||||
env['LOCAL_RANK'] = str(i)
|
||||
env['WORLD_SIZE'] = str(number_of_processes)
|
||||
env['LOCAL_WORLD_SIZE'] = str(number_of_processes)
|
||||
env['MASTER_ADDR'] = 'localhost'
|
||||
env['MASTER_PORT'] = '12345'
|
||||
p = multiprocessing.Process(target=fn, args=(env, ))
|
||||
env["RANK"] = str(i)
|
||||
env["LOCAL_RANK"] = str(i)
|
||||
env["WORLD_SIZE"] = str(number_of_processes)
|
||||
env["LOCAL_WORLD_SIZE"] = str(number_of_processes)
|
||||
env["MASTER_ADDR"] = "localhost"
|
||||
env["MASTER_PORT"] = "12345"
|
||||
p = multiprocessing.Process(target=fn, args=(env,))
|
||||
processes.append(p)
|
||||
p.start()
|
||||
|
||||
@@ -57,25 +57,23 @@ def worker_fn_wrapper(fn):
|
||||
|
||||
@worker_fn_wrapper
|
||||
def worker_fn():
|
||||
|
||||
rank = dist.get_rank()
|
||||
if rank == 0:
|
||||
port = get_open_port()
|
||||
ip = '127.0.0.1'
|
||||
ip = "127.0.0.1"
|
||||
dist.broadcast_object_list([ip, port], src=0)
|
||||
else:
|
||||
recv = [None, None]
|
||||
dist.broadcast_object_list(recv, src=0)
|
||||
ip, port = recv # type: ignore
|
||||
|
||||
stateless_pg = StatelessProcessGroup.create(ip, port, rank,
|
||||
dist.get_world_size())
|
||||
stateless_pg = StatelessProcessGroup.create(ip, port, rank, dist.get_world_size())
|
||||
|
||||
for pg in [dist.group.WORLD, stateless_pg]:
|
||||
|
||||
writer_rank = 2
|
||||
broadcaster = MessageQueue.create_from_process_group(
|
||||
pg, 40 * 1024, 2, writer_rank)
|
||||
pg, 40 * 1024, 2, writer_rank
|
||||
)
|
||||
if rank == writer_rank:
|
||||
seed = random.randint(0, 1000)
|
||||
dist.broadcast_object_list([seed], writer_rank)
|
||||
|
||||
Reference in New Issue
Block a user