Fix gpt oss weight loading with EP + bf16 (#28765)

Signed-off-by: ashors1 <ashors@nvidia.com>
This commit is contained in:
Anna Shors
2025-11-16 05:12:45 -08:00
committed by GitHub
parent 3bc1175798
commit 8d259fad6c

View File

@@ -494,8 +494,8 @@ class GptOssModel(nn.Module):
def _load_weights_other( def _load_weights_other(
self, self,
ep_rank_start: int,
ep_rank_end: int, ep_rank_end: int,
ep_rank_start: int,
heads_per_rank: int, heads_per_rank: int,
head_start: int, head_start: int,
weights: Iterable[tuple[str, torch.Tensor]], weights: Iterable[tuple[str, torch.Tensor]],