[Core][Test] fix function name typo in custom allreduce (#4750)
This commit is contained in:
@@ -25,7 +25,7 @@ def graph_allreduce(world_size, rank, distributed_init_port):
|
||||
init_test_distributed_environment(1, world_size, rank,
|
||||
distributed_init_port)
|
||||
|
||||
custom_all_reduce.init_custom_all_reduce()
|
||||
custom_all_reduce.init_custom_ar()
|
||||
for sz in test_sizes:
|
||||
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
|
||||
with custom_all_reduce.capture():
|
||||
@@ -61,7 +61,7 @@ def eager_allreduce(world_size, rank, distributed_init_port):
|
||||
distributed_init_port)
|
||||
|
||||
sz = 1024
|
||||
custom_all_reduce.init_custom_all_reduce()
|
||||
custom_all_reduce.init_custom_ar()
|
||||
fa = custom_all_reduce.get_handle()
|
||||
inp = torch.ones(sz, dtype=torch.float32, device=device)
|
||||
out = fa.all_reduce_unreg(inp)
|
||||
|
||||
Reference in New Issue
Block a user