Update deprecated type hinting in models (#18132)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-05-15 06:06:50 +01:00
committed by GitHub
parent 83f74c698f
commit 26d0419309
130 changed files with 971 additions and 901 deletions

View File

@@ -18,7 +18,7 @@
import math
from collections.abc import Iterable, Mapping
from itertools import tee
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
from typing import Literal, Optional, TypedDict, Union
import torch
from torch import nn
@@ -582,7 +582,7 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> List[PromptUpdate]:
) -> list[PromptUpdate]:
assert (
mm_items.get_count("image", strict=False) == 0
or "aspect_ratios" in out_mm_kwargs
@@ -778,26 +778,26 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
def separate_weights(
self,
weights: Iterable[Tuple[str, torch.Tensor]],
weights: Iterable[tuple[str, torch.Tensor]],
prefix: str,
) -> Tuple[Iterable[Tuple[str, torch.Tensor]], Iterable[Tuple[
) -> tuple[Iterable[tuple[str, torch.Tensor]], Iterable[tuple[
str, torch.Tensor]]]:
weights1, weights2 = tee(weights, 2)
def get_prefix_weights() -> Iterable[Tuple[str, torch.Tensor]]:
def get_prefix_weights() -> Iterable[tuple[str, torch.Tensor]]:
for name, data in weights1:
if name.startswith(prefix):
yield (name, data)
def get_other_weights() -> Iterable[Tuple[str, torch.Tensor]]:
def get_other_weights() -> Iterable[tuple[str, torch.Tensor]]:
for name, data in weights2:
if not name.startswith(prefix):
yield (name, data)
return get_prefix_weights(), get_other_weights()
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
@@ -806,7 +806,7 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
(".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
]
params_dict = dict(self.named_parameters())
updated_params: Set[str] = set()
updated_params: set[str] = set()
# language_model is a Llama4ForCausalLM instance. We load its weights
# using llama4's load_weights routine.