Fix gpt oss weight loading with EP + bf16 (#28765)
Signed-off-by: ashors1 <ashors@nvidia.com>
This commit is contained in:
@@ -494,8 +494,8 @@ class GptOssModel(nn.Module):
|
||||
|
||||
def _load_weights_other(
|
||||
self,
|
||||
ep_rank_start: int,
|
||||
ep_rank_end: int,
|
||||
ep_rank_start: int,
|
||||
heads_per_rank: int,
|
||||
head_start: int,
|
||||
weights: Iterable[tuple[str, torch.Tensor]],
|
||||
|
||||
Reference in New Issue
Block a user