Update deprecated type hinting in models (#18132)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-05-15 06:06:50 +01:00
committed by GitHub
parent 83f74c698f
commit 26d0419309
130 changed files with 971 additions and 901 deletions

View File

@@ -3,7 +3,8 @@
import copy
import math
import re
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union
from collections.abc import Iterable
from typing import Optional, Union
import torch
import torch.distributed
@@ -127,7 +128,7 @@ class MiniMaxText01RMSNormTP(CustomOp):
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert residual is None, "RMSNorm does not support residual connection."
return self._forward(x)
@@ -178,7 +179,7 @@ class MiniMaxText01RotaryEmbedding(CustomOp):
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
from vllm import _custom_ops as ops
self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
query_cast = query.to(self.cache_dtype)
@@ -708,11 +709,11 @@ class MiniMaxText01DecoderLayer(nn.Module):
def forward(self,
hidden_states: torch.Tensor,
positions: torch.Tensor,
kv_caches: Union[List[Dict], Optional[torch.Tensor]],
kv_caches: Union[list[dict], Optional[torch.Tensor]],
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor],
is_warmup: bool = False,
**kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
**kwargs) -> tuple[torch.Tensor, torch.Tensor]:
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
@@ -1072,10 +1073,10 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid,
device=device),
})
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
def which_layer(name: str) -> int:
if "layers" in name: