Update deprecated type hinting in models (#18132)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor authored 2025-05-15 06:06:50 +01:00 · committed by GitHub
parent 83f74c698f · commit 26d0419309
130 changed files with 971 additions and 901 deletions
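The pattern repeated across all 130 files is the PEP 585 cleanup: from Python 3.9 the builtin containers are themselves generic, so the deprecated typing.Dict/List/Set/Tuple aliases become plain dict/list/set/tuple, and abstract collection types such as Iterable are imported from collections.abc instead of typing. A minimal before/after sketch of the style (function and parameter names here are illustrative, not taken from the diff):

# Before (aliases deprecated since Python 3.9):
#   from typing import Dict, Iterable, Optional, Set, Tuple
#   def load_weights(weights: Iterable[Tuple[str, bytes]]) -> Set[str]: ...

# After (builtin generics per PEP 585; ABCs from collections.abc):
from collections.abc import Iterable
from typing import Optional

def load_weights(weights: Iterable[tuple[str, bytes]],
                 params_dict: Optional[dict[str, bytes]] = None) -> set[str]:
    loaded: set[str] = set()  # builtin set is subscriptable on 3.9+
    for name, _tensor in weights:
        loaded.add(name)
    return loaded

print(load_weights([("w1", b""), ("w2", b"")]))  # {'w1', 'w2'}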


@@ -16,7 +16,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only LLaMA model compatible with HuggingFace weights."""
-from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
+from collections.abc import Iterable
+from typing import Any, Optional
 
 import torch
 from torch import nn
@@ -48,7 +49,7 @@ class Llama4MoE(nn.Module):
         gating_output: torch.Tensor,
         topk: int,
         renormalize: bool,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         router_scores, router_indices = fast_topk(gating_output, topk, dim=-1)
         # pseudo-standard is that the router scores are floats
         router_scores = torch.sigmoid(router_scores.float())
@@ -115,7 +116,7 @@ class Llama4Attention(nn.Module):
         num_heads: int,
         num_kv_heads: int,
         rope_theta: float = 10000,
-        rope_scaling: Optional[Dict[str, Any]] = None,
+        rope_scaling: Optional[dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
         quant_config: Optional[QuantizationConfig] = None,
         bias: bool = False,
@@ -300,7 +301,7 @@ class Llama4DecoderLayer(nn.Module):
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         residual: Optional[torch.Tensor],
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
         if residual is None:
             residual = hidden_states
@@ -335,9 +336,9 @@ class Llama4Model(LlamaModel):
         self,
         name: str,
         loaded_weight: torch.Tensor,
-        params_dict: Dict[str, nn.Parameter],
-        loaded_params: Set[str],
-        expert_params_mapping: List[Tuple[str, str, int, str]],
+        params_dict: dict[str, nn.Parameter],
+        loaded_params: set[str],
+        expert_params_mapping: list[tuple[str, str, int, str]],
         fused: bool = True,
     ) -> bool:
         expert_param_loaded = False
@@ -390,8 +391,8 @@ class Llama4Model(LlamaModel):
                 expert_param_loaded = True
         return expert_param_loaded
 
-    def load_weights(self, weights: Iterable[Tuple[str,
-                                                   torch.Tensor]]) -> Set[str]:
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             (".qkv_proj", ".q_proj", "q"),
@@ -412,7 +413,7 @@ class Llama4Model(LlamaModel):
             ckpt_up_proj_name="gate_up_proj",
             num_experts=1)
         params_dict = dict(self.named_parameters())
-        loaded_params: Set[str] = set()
+        loaded_params: set[str] = set()
         for name, loaded_weight in weights:
             if "experts.gate_up_proj" in name or "experts.down_proj" in name:
                 fused_experts_params = True
@@ -489,8 +490,8 @@ class Llama4ForCausalLM(LlamaForCausalLM):
                            prefix=prefix,
                            layer_type=layer_type)
 
-    def load_weights(self, weights: Iterable[Tuple[str,
-                                                   torch.Tensor]]) -> Set[str]:
+    def load_weights(self, weights: Iterable[tuple[str,
+                                                   torch.Tensor]]) -> set[str]:
         loader = AutoWeightsLoader(
             self,
             skip_prefixes=(["lm_head."]
@@ -506,7 +507,7 @@ class Llama4ForCausalLM(LlamaForCausalLM):
         self,
         name: str,
         loaded_weight: torch.Tensor,
-    ) -> Tuple[str, torch.Tensor]:
+    ) -> tuple[str, torch.Tensor]:
 
         def permute(w: torch.Tensor, n_heads: int):
             attn_in = self.config.head_dim * n_heads
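One caveat worth flagging when applying this migration (general Python behavior, not something this diff changes, and presumably why the cleanup is safe for this project): annotations in signatures are evaluated at import time by default, so tuple[str, torch.Tensor] and friends require Python 3.9+ at runtime. On 3.7/3.8 the same spelling works only with postponed evaluation; a minimal sketch:

from __future__ import annotations  # PEP 563: annotations stay as unevaluated strings

# With the future import, builtin generics in annotations parse even on
# interpreters older than 3.9, because the annotations are never evaluated.
def pair_areas(xs: list[tuple[int, int]]) -> dict[str, int]:
    return {f"pair{i}": a * b for i, (a, b) in enumerate(xs)}

print(pair_areas([(2, 3), (4, 5)]))  # {'pair0': 6, 'pair1': 20}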