Update deprecated type hinting in models (#18132)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-05-15 06:06:50 +01:00
committed by GitHub
parent 83f74c698f
commit 26d0419309
130 changed files with 971 additions and 901 deletions

View File

@@ -4,7 +4,7 @@ import math
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass
from functools import cached_property, partial
from typing import List, Optional, Set, Tuple, TypedDict, Union
from typing import Optional, TypedDict, Union
import numpy as np
import torch
@@ -90,7 +90,7 @@ class MolmoImageInputs(TypedDict):
@dataclass
class VisionBackboneConfig:
image_default_input_size: Tuple[int, int] = (336, 336)
image_default_input_size: tuple[int, int] = (336, 336)
image_patch_size: int = 14
image_pos_patch_size: int = 14
image_emb_dim: int = 1024
@@ -267,7 +267,7 @@ class BlockCollection(nn.Module):
for _ in range(config.image_num_layers)
])
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
hidden_states = []
for r in self.resblocks:
x = r(x)
@@ -334,7 +334,7 @@ class VisionTransformer(nn.Module):
def forward(self,
x: torch.Tensor,
patch_num: Optional[int] = None) -> List[torch.Tensor]:
patch_num: Optional[int] = None) -> list[torch.Tensor]:
"""
: param x: (batch_size, num_patch, n_pixels)
"""
@@ -434,7 +434,7 @@ class MolmoAttention(nn.Module):
)
def _apply_qk_norm(self, q: torch.Tensor,
k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
if self.tp_size > 1:
q = tensor_model_parallel_all_gather(q.contiguous())
k = tensor_model_parallel_all_gather(k.contiguous())
@@ -570,7 +570,7 @@ class MolmoDecoderLayer(nn.Module):
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, torch.Tensor]]]:
# Self Attention
if residual is None:
residual = hidden_states
@@ -596,7 +596,7 @@ class MolmoDecoderNormAfterLayer(MolmoDecoderLayer):
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, torch.Tensor]]]:
# Self Attention
residual = hidden_states
hidden_states = self.self_attn(
@@ -740,15 +740,15 @@ class MolmoVisionBackbone(nn.Module, SupportsQuant):
# image_features: (batch_size, num_image, num_patch, d_model)
return image_features
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("merged_linear", "gate_proj", 0),
("merged_linear", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
for (param_name, weight_name, shard_id) in stacked_params_mapping:
@@ -855,10 +855,10 @@ class MolmoModel(nn.Module, SupportsQuant):
hidden_states = self.norm(hidden_states)
return hidden_states
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if name.endswith(".bias") and name not in params_dict:
@@ -1530,7 +1530,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
sampling_metadata)
return logits
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(self)
weights = _get_weights_with_merged_embedding(weights)
@@ -1548,8 +1548,8 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
def _get_weights_with_merged_embedding(
weights: Iterable[Tuple[str, torch.Tensor]]
) -> Iterable[Tuple[str, torch.Tensor]]:
weights: Iterable[tuple[str, torch.Tensor]]
) -> Iterable[tuple[str, torch.Tensor]]:
embedding_weights = {}
for name, weight in weights:
if "wte.embedding" in name: