[Distributed] Add custom allreduce support for ROCM (#14125)

Signed-off-by: ilmarkov <imarkov@redhat.com>
Co-authored-by: ilmarkov <imarkov@redhat.com>
This commit is contained in:
Ilya Markov
2025-04-01 07:49:12 +02:00
committed by GitHub
parent e6e3c55ef2
commit b7b7676d67
13 changed files with 373 additions and 160 deletions

View File

@@ -106,7 +106,7 @@ def eager_allreduce(
# communicate independently
num_communication = rank // tp_size + 1
sz = 1024
fa = get_tp_group().ca_comm
fa = get_tp_group().device_communicator.ca_comm
inp = torch.ones(sz, dtype=torch.float32, device=device)
out = inp
for _ in range(num_communication):