Update deprecated type hinting in models (#18132)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-05-15 06:06:50 +01:00
committed by GitHub
parent 83f74c698f
commit 26d0419309
130 changed files with 971 additions and 901 deletions

View File

@@ -2,8 +2,7 @@
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import (Final, List, Literal, Optional, Protocol, Set, Tuple,
TypedDict, Union)
from typing import Final, Literal, Optional, Protocol, TypedDict, Union
import torch
import torch.nn as nn
@@ -471,8 +470,8 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
return data
def _validate_image_pixel_values(
self, data: Union[torch.Tensor, List[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor]]:
self, data: Union[torch.Tensor, list[torch.Tensor]]
) -> Union[torch.Tensor, list[torch.Tensor]]:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
@@ -530,8 +529,8 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
raise AssertionError("This line should be unreachable.")
def _validate_video_pixel_values(
self, data: Union[torch.Tensor, List[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor]]:
self, data: Union[torch.Tensor, list[torch.Tensor]]
) -> Union[torch.Tensor, list[torch.Tensor]]:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
@@ -557,7 +556,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
A legal video input should have the following dimensions:
{
"pixel_values_videos" :
List[b, Tensor(nb_frames, nb_channels, height, width)]
list[b, Tensor(nb_frames, nb_channels, height, width)]
}
"""
pixel_values_videos = kwargs.pop("pixel_values_videos", None)
@@ -706,7 +705,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
def _process_image_pixels(
self,
inputs: LlavaOnevisionImagePixelInputs,
) -> Union[torch.Tensor, List[torch.Tensor]]:
) -> Union[torch.Tensor, list[torch.Tensor]]:
assert self.vision_tower is not None
pixel_values = inputs["pixel_values"]
@@ -735,7 +734,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
def _process_image_input(
self,
image_input: LlavaOnevisionImageInputs,
) -> Union[torch.Tensor, List[torch.Tensor]]:
) -> Union[torch.Tensor, list[torch.Tensor]]:
if image_input["type"] == "image_embeds":
return [image_input["data"]]
@@ -948,7 +947,7 @@ class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal,
return self.language_model.compute_logits(hidden_states,
sampling_metadata)
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)