[Misc/Testing] Use torch.testing.assert_close (#7324)
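The motivation for the change: a bare "assert torch.allclose(...)" fails with nothing more than "assert False", while torch.testing.assert_close raises an AssertionError that reports how many elements mismatch and the greatest absolute and relative differences, and additionally checks that dtype and device agree. A minimal illustration (not part of the diff; the tensor values are made up):

import torch

actual = torch.tensor([1.0, 2.0, 3.0])
expected = torch.tensor([1.0, 2.0, 3.5])

# Old style: on failure pytest can only report "assert False".
print(torch.allclose(actual, expected))  # False, with no detail

# New style: the raised AssertionError pinpoints the mismatch
# ("Mismatched elements: 1 / 3", greatest absolute/relative difference, ...).
try:
    torch.testing.assert_close(actual, expected)
except AssertionError as err:
    print(err)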
@@ -34,7 +34,7 @@ def all_reduce_test_worker(tp_size: int, pp_size: int, rank: int,
     expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
     t = all_tensors[rank % tp_size]
     t = tensor_model_parallel_all_reduce(t)
-    assert torch.allclose(t, expected)
+    torch.testing.assert_close(t, expected)
 
 
 @ray.remote(num_gpus=1, max_calls=1)
@@ -62,7 +62,7 @@ def all_gather_test_worker(tp_size: int, pp_size: int, rank: int,
     expected = torch.cat(all_tensors, dim=all_gather_dimension)
     t = all_tensors[rank % tp_size]
     t = tensor_model_parallel_all_gather(t, all_gather_dimension)
-    assert torch.allclose(t, expected)
+    torch.testing.assert_close(t, expected)
 
 
 @ray.remote(num_gpus=1, max_calls=1)
@@ -96,12 +96,12 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
     else:
         recv_dict = broadcast_tensor_dict(src=0)
     assert len(recv_dict) == len(test_dict)
-    assert torch.allclose(recv_dict["a"], test_dict["a"])
-    assert torch.allclose(recv_dict["b"], test_dict["b"])
+    torch.testing.assert_close(recv_dict["a"], test_dict["a"])
+    torch.testing.assert_close(recv_dict["b"], test_dict["b"])
     assert recv_dict["c"] == test_dict["c"]
     assert recv_dict["d"] == test_dict["d"]
     assert recv_dict["e"] == test_dict["e"]
-    assert torch.allclose(recv_dict["f"], test_dict["f"])
+    torch.testing.assert_close(recv_dict["f"], test_dict["f"])
 
 
 @ray.remote(num_gpus=1, max_calls=1)
@@ -136,12 +136,12 @@ def send_recv_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
 
     if not get_pp_group().is_first_rank:
         assert len(recv_dict) == len(test_dict)
-        assert torch.allclose(recv_dict["a"], test_dict["a"])
-        assert torch.allclose(recv_dict["b"], test_dict["b"])
+        torch.testing.assert_close(recv_dict["a"], test_dict["a"])
+        torch.testing.assert_close(recv_dict["b"], test_dict["b"])
         assert recv_dict["c"] == test_dict["c"]
         assert recv_dict["d"] == test_dict["d"]
         assert recv_dict["e"] == test_dict["e"]
-        assert torch.allclose(recv_dict["f"], test_dict["f"])
+        torch.testing.assert_close(recv_dict["f"], test_dict["f"])
 
 
 @ray.remote(num_gpus=1, max_calls=1)
@@ -163,7 +163,7 @@ def send_recv_test_worker(tp_size: int, pp_size: int, rank: int,
         get_pp_group().send(test_tensor)
 
     if not get_pp_group().is_first_rank:
-        assert torch.allclose(test_tensor, recv_tensor)
+        torch.testing.assert_close(test_tensor, recv_tensor)
 
 
 @pytest.mark.skipif(torch.cuda.device_count() < 2,
@@ -72,8 +72,8 @@ def graph_allreduce(tp_size, pp_size, rank, distributed_init_port):
                         out2 = tensor_model_parallel_all_reduce(inp2)
                         dist.all_reduce(inp2, group=group)
             graph.replay()
-            assert torch.allclose(out1, inp1)
-            assert torch.allclose(out2, inp2)
+            torch.testing.assert_close(out1, inp1)
+            torch.testing.assert_close(out2, inp2)
 
 
 @ray.remote(num_gpus=1, max_calls=1)
@@ -96,13 +96,13 @@ def eager_allreduce(tp_size, pp_size, rank, distributed_init_port):
     out = inp
     for _ in range(num_communication):
         out = fa.all_reduce_unreg(out)
-    assert torch.allclose(out, inp * (tp_size**num_communication))
+    torch.testing.assert_close(out, inp * (tp_size**num_communication))
 
     inp = torch.ones(sz * 4, dtype=torch.bfloat16, device=device)
     out = inp
     for _ in range(num_communication):
        out = fa.all_reduce_unreg(out)
-    assert torch.allclose(out, inp * (tp_size**num_communication))
+    torch.testing.assert_close(out, inp * (tp_size**num_communication))
 
 
 @pytest.mark.parametrize("tp_size", [2])
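One detail worth noting for the bfloat16 check in eager_allreduce above: torch.allclose applies the same defaults (rtol=1e-5, atol=1e-8) to every dtype, whereas torch.testing.assert_close derives its default tolerances from the input dtype (bfloat16 defaults to rtol=1.6e-2, atol=1e-5 per the PyTorch docs), so low-precision results get a correspondingly looser bound. A small sketch, with the tensor values made up for illustration:

import torch

a = torch.full((8, ), 3.0, dtype=torch.bfloat16)
b = a * 1.01  # ~1% relative error, plausible for low-precision accumulation

# allclose judges bf16 by the same rtol=1e-5 it uses for float64.
print(torch.allclose(a, b))       # False: 1% error exceeds rtol=1e-5

# assert_close picks bf16-appropriate defaults, so this passes.
torch.testing.assert_close(a, b)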