Update deprecated type hinting in models (#18132)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from functools import cached_property
|
||||
from typing import Any, Dict, Literal, Optional, Set, Tuple, TypedDict, Union
|
||||
from typing import Any, Literal, Optional, TypedDict, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -229,7 +229,7 @@ class ChameleonAttention(nn.Module):
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
rope_scaling: Optional[dict[str, Any]] = None,
|
||||
max_position_embeddings: int = 4096,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
bias: bool = False,
|
||||
@@ -292,7 +292,7 @@ class ChameleonAttention(nn.Module):
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
def _apply_qk_norm(self, q: torch.Tensor,
|
||||
k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# reshape for layernorm
|
||||
q = q.reshape(-1, self.num_heads, self.head_dim)
|
||||
k = k.reshape(-1, self.num_kv_heads, self.head_dim)
|
||||
@@ -367,7 +367,7 @@ class ChameleonDecoderLayer(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
@@ -438,7 +438,7 @@ class ChameleonSwinDecoderLayer(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.self_attn(
|
||||
@@ -773,7 +773,7 @@ class ChameleonVQVAE(nn.Module):
|
||||
|
||||
def encode(
|
||||
self, pixel_values: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
hidden_states = self.encoder(pixel_values)
|
||||
hidden_states = self.quant_conv(hidden_states)
|
||||
quant, emb_loss, indices = self.quantize(hidden_states)
|
||||
@@ -786,7 +786,7 @@ class ChameleonImageVocabularyMapping:
|
||||
A class for mapping discrete image tokens from VQGAN to BPE tokens.
|
||||
"""
|
||||
|
||||
def __init__(self, vocab_map: Dict[str, int]):
|
||||
def __init__(self, vocab_map: dict[str, int]):
|
||||
self.vocab_map = vocab_map
|
||||
self.image_token_id = vocab_map.get("<image>")
|
||||
|
||||
@@ -1052,8 +1052,8 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
return logits
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
(".qkv_proj", ".q_proj", "q"),
|
||||
@@ -1063,7 +1063,7 @@ class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
(".gate_up_proj", ".up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
loaded_params: set[str] = set()
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
Reference in New Issue
Block a user