Update deprecated type hinting in models (#18132)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-05-15 06:06:50 +01:00
committed by GitHub
parent 83f74c698f
commit 26d0419309
130 changed files with 971 additions and 901 deletions

View File

@@ -2,7 +2,7 @@
from collections.abc import Iterable, Mapping, Sequence
from functools import cached_property
from typing import Any, Dict, Literal, Optional, Set, Tuple, TypedDict, Union
from typing import Any, Literal, Optional, TypedDict, Union
import torch
import torch.nn as nn
@@ -229,7 +229,7 @@ class ChameleonAttention(nn.Module):
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
rope_scaling: Optional[dict[str, Any]] = None,
max_position_embeddings: int = 4096,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
@@ -292,7 +292,7 @@ class ChameleonAttention(nn.Module):
prefix=f"{prefix}.attn")
def _apply_qk_norm(self, q: torch.Tensor,
k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
# reshape for layernorm
q = q.reshape(-1, self.num_heads, self.head_dim)
k = k.reshape(-1, self.num_kv_heads, self.head_dim)
@@ -367,7 +367,7 @@ class ChameleonDecoderLayer(nn.Module):
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
if residual is None:
residual = hidden_states
@@ -438,7 +438,7 @@ class ChameleonSwinDecoderLayer(nn.Module):
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor]:
residual = hidden_states
hidden_states = self.self_attn(
@@ -773,7 +773,7 @@ class ChameleonVQVAE(nn.Module):
def encode(
self, pixel_values: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
hidden_states = self.encoder(pixel_values)
hidden_states = self.quant_conv(hidden_states)
quant, emb_loss, indices = self.quantize(hidden_states)
@@ -786,7 +786,7 @@ class ChameleonImageVocabularyMapping:
A class for mapping discrete image tokens from VQGAN to BPE tokens.
"""
def __init__(self, vocab_map: Dict[str, int]):
def __init__(self, vocab_map: dict[str, int]):
self.vocab_map = vocab_map
self.image_token_id = vocab_map.get("<image>")
@@ -1052,8 +1052,8 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
return logits
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
@@ -1063,7 +1063,7 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
(".gate_up_proj", ".up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue