[Misc/Testing] Use torch.testing.assert_close (#7324)

This commit is contained in:
jon-chuang
2024-08-15 21:24:04 -07:00
committed by GitHub
parent e165528778
commit 50b8d08dbd
25 changed files with 197 additions and 188 deletions

View File

@@ -34,7 +34,7 @@ def all_reduce_test_worker(tp_size: int, pp_size: int, rank: int,
expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
t = all_tensors[rank % tp_size]
t = tensor_model_parallel_all_reduce(t)
assert torch.allclose(t, expected)
torch.testing.assert_close(t, expected)
@ray.remote(num_gpus=1, max_calls=1)
@@ -62,7 +62,7 @@ def all_gather_test_worker(tp_size: int, pp_size: int, rank: int,
expected = torch.cat(all_tensors, dim=all_gather_dimension)
t = all_tensors[rank % tp_size]
t = tensor_model_parallel_all_gather(t, all_gather_dimension)
assert torch.allclose(t, expected)
torch.testing.assert_close(t, expected)
@ray.remote(num_gpus=1, max_calls=1)
@@ -96,12 +96,12 @@ def broadcast_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
else:
recv_dict = broadcast_tensor_dict(src=0)
assert len(recv_dict) == len(test_dict)
assert torch.allclose(recv_dict["a"], test_dict["a"])
assert torch.allclose(recv_dict["b"], test_dict["b"])
torch.testing.assert_close(recv_dict["a"], test_dict["a"])
torch.testing.assert_close(recv_dict["b"], test_dict["b"])
assert recv_dict["c"] == test_dict["c"]
assert recv_dict["d"] == test_dict["d"]
assert recv_dict["e"] == test_dict["e"]
assert torch.allclose(recv_dict["f"], test_dict["f"])
torch.testing.assert_close(recv_dict["f"], test_dict["f"])
@ray.remote(num_gpus=1, max_calls=1)
@@ -136,12 +136,12 @@ def send_recv_tensor_dict_test_worker(tp_size: int, pp_size: int, rank: int,
if not get_pp_group().is_first_rank:
assert len(recv_dict) == len(test_dict)
assert torch.allclose(recv_dict["a"], test_dict["a"])
assert torch.allclose(recv_dict["b"], test_dict["b"])
torch.testing.assert_close(recv_dict["a"], test_dict["a"])
torch.testing.assert_close(recv_dict["b"], test_dict["b"])
assert recv_dict["c"] == test_dict["c"]
assert recv_dict["d"] == test_dict["d"]
assert recv_dict["e"] == test_dict["e"]
assert torch.allclose(recv_dict["f"], test_dict["f"])
torch.testing.assert_close(recv_dict["f"], test_dict["f"])
@ray.remote(num_gpus=1, max_calls=1)
@@ -163,7 +163,7 @@ def send_recv_test_worker(tp_size: int, pp_size: int, rank: int,
get_pp_group().send(test_tensor)
if not get_pp_group().is_first_rank:
assert torch.allclose(test_tensor, recv_tensor)
torch.testing.assert_close(test_tensor, recv_tensor)
@pytest.mark.skipif(torch.cuda.device_count() < 2,

View File

@@ -72,8 +72,8 @@ def graph_allreduce(tp_size, pp_size, rank, distributed_init_port):
out2 = tensor_model_parallel_all_reduce(inp2)
dist.all_reduce(inp2, group=group)
graph.replay()
assert torch.allclose(out1, inp1)
assert torch.allclose(out2, inp2)
torch.testing.assert_close(out1, inp1)
torch.testing.assert_close(out2, inp2)
@ray.remote(num_gpus=1, max_calls=1)
@@ -96,13 +96,13 @@ def eager_allreduce(tp_size, pp_size, rank, distributed_init_port):
out = inp
for _ in range(num_communication):
out = fa.all_reduce_unreg(out)
assert torch.allclose(out, inp * (tp_size**num_communication))
torch.testing.assert_close(out, inp * (tp_size**num_communication))
inp = torch.ones(sz * 4, dtype=torch.bfloat16, device=device)
out = inp
for _ in range(num_communication):
out = fa.all_reduce_unreg(out)
assert torch.allclose(out, inp * (tp_size**num_communication))
torch.testing.assert_close(out, inp * (tp_size**num_communication))
@pytest.mark.parametrize("tp_size", [2])