Update deprecated type hinting in models (#18132)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -18,7 +18,7 @@
|
||||
import math
|
||||
from collections.abc import Iterable, Mapping
|
||||
from itertools import tee
|
||||
from typing import List, Literal, Optional, Set, Tuple, TypedDict, Union
|
||||
from typing import Literal, Optional, TypedDict, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -582,7 +582,7 @@ class Mllama4MultiModalProcessor(BaseMultiModalProcessor[Mllama4ProcessingInfo]
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> List[PromptUpdate]:
|
||||
) -> list[PromptUpdate]:
|
||||
assert (
|
||||
mm_items.get_count("image", strict=False) == 0
|
||||
or "aspect_ratios" in out_mm_kwargs
|
||||
@@ -778,26 +778,26 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
|
||||
def separate_weights(
|
||||
self,
|
||||
weights: Iterable[Tuple[str, torch.Tensor]],
|
||||
weights: Iterable[tuple[str, torch.Tensor]],
|
||||
prefix: str,
|
||||
) -> Tuple[Iterable[Tuple[str, torch.Tensor]], Iterable[Tuple[
|
||||
) -> tuple[Iterable[tuple[str, torch.Tensor]], Iterable[tuple[
|
||||
str, torch.Tensor]]]:
|
||||
weights1, weights2 = tee(weights, 2)
|
||||
|
||||
def get_prefix_weights() -> Iterable[Tuple[str, torch.Tensor]]:
|
||||
def get_prefix_weights() -> Iterable[tuple[str, torch.Tensor]]:
|
||||
for name, data in weights1:
|
||||
if name.startswith(prefix):
|
||||
yield (name, data)
|
||||
|
||||
def get_other_weights() -> Iterable[Tuple[str, torch.Tensor]]:
|
||||
def get_other_weights() -> Iterable[tuple[str, torch.Tensor]]:
|
||||
for name, data in weights2:
|
||||
if not name.startswith(prefix):
|
||||
yield (name, data)
|
||||
|
||||
return get_prefix_weights(), get_other_weights()
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
torch.Tensor]]) -> Set[str]:
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
@@ -806,7 +806,7 @@ class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal,
|
||||
(".self_attn.qkv_proj", ".self_attn.v_proj", "v"),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
updated_params: Set[str] = set()
|
||||
updated_params: set[str] = set()
|
||||
|
||||
# language_model is an Llama4ForCausalLM instance. We load it's
|
||||
# using llama4's load_weights routine.
|
||||
|
||||
Reference in New Issue
Block a user