Fix gpt oss weight loading with EP + bf16 (#28765)
Signed-off-by: ashors1 <ashors@nvidia.com>
This commit is contained in:
@@ -494,8 +494,8 @@ class GptOssModel(nn.Module):
|
|||||||
|
|
||||||
def _load_weights_other(
|
def _load_weights_other(
|
||||||
self,
|
self,
|
||||||
ep_rank_start: int,
|
|
||||||
ep_rank_end: int,
|
ep_rank_end: int,
|
||||||
|
ep_rank_start: int,
|
||||||
heads_per_rank: int,
|
heads_per_rank: int,
|
||||||
head_start: int,
|
head_start: int,
|
||||||
weights: Iterable[tuple[str, torch.Tensor]],
|
weights: Iterable[tuple[str, torch.Tensor]],
|
||||||
|
|||||||
Reference in New Issue
Block a user