[RLHF] Fix torch.dtype not serializable in example (#22158)

Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
This commit is contained in:
22quinn
2025-08-03 19:43:33 -07:00
committed by GitHub
parent e27d25a0dc
commit 845420ac2c
2 changed files with 6 additions and 2 deletions

View File

@@ -45,7 +45,8 @@ class WorkerExtension:
self.device,
)
def update_weight(self, name, dtype, shape):
def update_weight(self, name, dtype_name, shape):
dtype = getattr(torch, dtype_name)
weight = torch.empty(shape, dtype=dtype, device="cuda")
self.model_update_group.broadcast(
weight, src=0, stream=torch.cuda.current_stream()