[Misc/Testing] Use torch.testing.assert_close (#7324)
@@ -72,8 +72,8 @@ def graph_allreduce(tp_size, pp_size, rank, distributed_init_port):
                 out2 = tensor_model_parallel_all_reduce(inp2)
                 dist.all_reduce(inp2, group=group)
         graph.replay()
-        assert torch.allclose(out1, inp1)
-        assert torch.allclose(out2, inp2)
+        torch.testing.assert_close(out1, inp1)
+        torch.testing.assert_close(out2, inp2)
 
 
 @ray.remote(num_gpus=1, max_calls=1)
@@ -96,13 +96,13 @@ def eager_allreduce(tp_size, pp_size, rank, distributed_init_port):
     out = inp
     for _ in range(num_communication):
         out = fa.all_reduce_unreg(out)
-    assert torch.allclose(out, inp * (tp_size**num_communication))
+    torch.testing.assert_close(out, inp * (tp_size**num_communication))
 
     inp = torch.ones(sz * 4, dtype=torch.bfloat16, device=device)
     out = inp
     for _ in range(num_communication):
         out = fa.all_reduce_unreg(out)
-    assert torch.allclose(out, inp * (tp_size**num_communication))
+    torch.testing.assert_close(out, inp * (tp_size**num_communication))
 
 
 @pytest.mark.parametrize("tp_size", [2])
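
Note on the motivation (an illustrative aside, not part of the commit): torch.testing.assert_close checks dtype and shape and, on failure, reports how many elements mismatch and the greatest absolute/relative difference, using dtype-aware default tolerances, whereas a bare assert torch.allclose(...) raises an AssertionError with no message. The sketch below is a minimal, self-contained example; the tensors a and b are made up for illustration.

import torch

a = torch.ones(4, dtype=torch.bfloat16)
b = a.clone()
b[0] = 2.0  # introduce a single mismatching element

# Old style: a failing allclose assert gives no detail about the mismatch.
try:
    assert torch.allclose(a, b)
except AssertionError as err:
    print("allclose:", repr(err))  # bare AssertionError, empty message

# New style: assert_close raises an AssertionError whose message includes
# the mismatch count and the greatest absolute/relative difference, with
# default rtol/atol chosen for the dtype (here bfloat16).
try:
    torch.testing.assert_close(a, b)
except AssertionError as err:
    print("assert_close:", err)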