[Core][Distributed] support cpu&device in broadcast tensor dict (#4660)
[Core][Distributed] support both cpu and device tensor in broadcast tensor dict (#4660)
This commit is contained in:
@@ -77,14 +77,18 @@ def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int,
|
||||
init_test_distributed_environment(1, tensor_parallel_size, rank,
|
||||
distributed_init_port)
|
||||
test_dict = {
|
||||
# device tensor
|
||||
"a": torch.arange(8, dtype=torch.float32, device="cuda"),
|
||||
"b": torch.arange(16, dtype=torch.int8, device="cuda"),
|
||||
# CPU tensor
|
||||
"b": torch.arange(16, dtype=torch.int8, device="cpu"),
|
||||
"c": "test",
|
||||
"d": [1, 2, 3],
|
||||
"e": {
|
||||
"a": 1,
|
||||
"b": 2
|
||||
},
|
||||
# empty tensor
|
||||
"f": torch.tensor([], dtype=torch.float32, device="cuda"),
|
||||
}
|
||||
|
||||
if rank == 0:
|
||||
@@ -97,6 +101,7 @@ def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int,
|
||||
assert recv_dict["c"] == test_dict["c"]
|
||||
assert recv_dict["d"] == test_dict["d"]
|
||||
assert recv_dict["e"] == test_dict["e"]
|
||||
assert torch.allclose(recv_dict["f"], test_dict["f"])
|
||||
|
||||
|
||||
@pytest.mark.skipif(torch.cuda.device_count() < 2,
|
||||
|
||||
Reference in New Issue
Block a user