[Core][Distributed] Refactor ipc buffer init in CustomAllreduce (#10030)
Signed-off-by: Hanzhi Zhou <hanzhi713@gmail.com>
This commit is contained in:
@@ -95,13 +95,13 @@ def eager_allreduce(tp_size, pp_size, rank, distributed_init_port):
|
||||
inp = torch.ones(sz, dtype=torch.float32, device=device)
|
||||
out = inp
|
||||
for _ in range(num_communication):
|
||||
out = fa.all_reduce_unreg(out)
|
||||
out = fa.all_reduce(out, registered=False)
|
||||
torch.testing.assert_close(out, inp * (tp_size**num_communication))
|
||||
|
||||
inp = torch.ones(sz * 4, dtype=torch.bfloat16, device=device)
|
||||
out = inp
|
||||
for _ in range(num_communication):
|
||||
out = fa.all_reduce_unreg(out)
|
||||
out = fa.all_reduce(out, registered=False)
|
||||
torch.testing.assert_close(out, inp * (tp_size**num_communication))
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user