[Core][Bugfix] Use correct device to initialize GPU data during CUDA-graph-capture (#11233)

Signed-off-by: Yan Burman <yanburman@users.noreply.github.com>
Signed-off-by: Ido Asraff <idoa@atero.ai>
This commit is contained in:
Yan Burman
2025-01-04 08:50:16 +02:00
committed by GitHub
parent d91457d529
commit 300acb8347
5 changed files with 23 additions and 15 deletions

View File

@@ -50,7 +50,7 @@ def graph_allreduce(tp_size, pp_size, rank, distributed_init_port):
for sz in test_sizes:
for dtype in [torch.float32, torch.float16, torch.bfloat16]:
with graph_capture() as graph_capture_context:
with graph_capture(device=device) as graph_capture_context:
# use integers so result matches NCCL exactly
inp1 = torch.randint(1,
16, (sz, ),