Update deprecated type hinting in model_executor/layers (#18056)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-05-13 12:17:23 +01:00
committed by GitHub
parent 906f0598fc
commit 6223dd8114
87 changed files with 523 additions and 523 deletions

View File

@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
from typing import List, Optional, Tuple, Union
from typing import Optional, Union
import torch
from torch import nn
@@ -104,7 +104,7 @@ class Mixer2RMSNormGated(CustomOp):
self,
x: torch.Tensor,
gate: torch.Tensor,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if self.tp_size > 1 or self.n_groups != 1:
return self.forward_native(x, gate)
@@ -136,7 +136,7 @@ def extra_groups_for_head_shards(ngroups: int, tp_size: int):
def mamba_v2_sharded_weight_loader(
shard_spec: List[Tuple[int, int, float]],
shard_spec: list[tuple[int, int, float]],
tp_size: int,
tp_rank: int,
) -> LoaderFunction: