Update deprecated type hinting in models (#18132)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -4,7 +4,7 @@ import math
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property, partial
|
||||
from typing import List, Optional, Set, Tuple, TypedDict, Union
|
||||
from typing import Optional, TypedDict, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -90,7 +90,7 @@ class MolmoImageInputs(TypedDict):
|
||||
|
||||
@dataclass
|
||||
class VisionBackboneConfig:
|
||||
image_default_input_size: Tuple[int, int] = (336, 336)
|
||||
image_default_input_size: tuple[int, int] = (336, 336)
|
||||
image_patch_size: int = 14
|
||||
image_pos_patch_size: int = 14
|
||||
image_emb_dim: int = 1024
|
||||
@@ -267,7 +267,7 @@ class BlockCollection(nn.Module):
|
||||
for _ in range(config.image_num_layers)
|
||||
])
|
||||
|
||||
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
|
||||
def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
|
||||
hidden_states = []
|
||||
for r in self.resblocks:
|
||||
x = r(x)
|
||||
@@ -334,7 +334,7 @@ class VisionTransformer(nn.Module):
|
||||
|
||||
def forward(self,
|
||||
x: torch.Tensor,
|
||||
patch_num: Optional[int] = None) -> List[torch.Tensor]:
|
||||
patch_num: Optional[int] = None) -> list[torch.Tensor]:
|
||||
"""
|
||||
: param x: (batch_size, num_patch, n_pixels)
|
||||
"""
|
||||
@@ -434,7 +434,7 @@ class MolmoAttention(nn.Module):
|
||||
)
|
||||
|
||||
def _apply_qk_norm(self, q: torch.Tensor,
|
||||
k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if self.tp_size > 1:
|
||||
q = tensor_model_parallel_all_gather(q.contiguous())
|
||||
k = tensor_model_parallel_all_gather(k.contiguous())
|
||||
@@ -570,7 +570,7 @@ class MolmoDecoderLayer(nn.Module):
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
||||
) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, torch.Tensor]]]:
|
||||
# Self Attention
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
@@ -596,7 +596,7 @@ class MolmoDecoderNormAfterLayer(MolmoDecoderLayer):
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
|
||||
) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, torch.Tensor]]]:
|
||||
# Self Attention
|
||||
residual = hidden_states
|
||||
hidden_states = self.self_attn(
|
||||
@@ -740,15 +740,15 @@ class MolmoVisionBackbone(nn.Module, SupportsQuant):
|
||||
# image_features: (batch_size, num_image, num_patch, d_model)
|
||||
return image_features
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("merged_linear", "gate_proj", 0),
|
||||
("merged_linear", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
loaded_params: set[str] = set()
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
@@ -855,10 +855,10 @@ class MolmoModel(nn.Module, SupportsQuant):
|
||||
hidden_states = self.norm(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
loaded_params: set[str] = set()
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
@@ -1530,7 +1530,7 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
|
||||
loader = AutoWeightsLoader(self)
|
||||
weights = _get_weights_with_merged_embedding(weights)
|
||||
@@ -1548,8 +1548,8 @@ class MolmoForCausalLM(nn.Module, SupportsMultiModal, SupportsPP, SupportsLoRA,
|
||||
|
||||
|
||||
def _get_weights_with_merged_embedding(
|
||||
weights: Iterable[Tuple[str, torch.Tensor]]
|
||||
) -> Iterable[Tuple[str, torch.Tensor]]:
|
||||
weights: Iterable[tuple[str, torch.Tensor]]
|
||||
) -> Iterable[tuple[str, torch.Tensor]]:
|
||||
embedding_weights = {}
|
||||
for name, weight in weights:
|
||||
if "wte.embedding" in name:
|
||||
|
||||
Reference in New Issue
Block a user