Update deprecated type hinting in models (#18132)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import math
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from typing import Any, Dict, List, Literal, Optional, Tuple, TypedDict, Union
|
||||
from typing import Any, Literal, Optional, TypedDict, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -392,7 +392,7 @@ class Phi4MMImageEncoder(nn.Module):
|
||||
|
||||
class Phi4MMImagePixelInputs(TypedDict):
|
||||
type: Literal["pixel_values"]
|
||||
data: Union[torch.Tensor, List[torch.Tensor]]
|
||||
data: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""
|
||||
Shape:
|
||||
`(batch_size * num_images, 1 + num_patches, num_channels, height, width)`
|
||||
@@ -417,7 +417,7 @@ class Phi4MMImagePixelInputs(TypedDict):
|
||||
|
||||
class Phi4MMImageEmbeddingInputs(TypedDict):
|
||||
type: Literal["image_embeds"]
|
||||
data: Union[torch.Tensor, List[torch.Tensor]]
|
||||
data: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""Shape: `(batch_size * num_images, image_feature_size, hidden_size)`
|
||||
|
||||
`hidden_size` must match the hidden size of language model backbone.
|
||||
@@ -426,7 +426,7 @@ class Phi4MMImageEmbeddingInputs(TypedDict):
|
||||
|
||||
class Phi4MMAudioFeatureInputs(TypedDict):
|
||||
type: Literal["audio_features"]
|
||||
data: Union[torch.Tensor, List[torch.Tensor]]
|
||||
data: Union[torch.Tensor, list[torch.Tensor]]
|
||||
"""Shape: `(batch_size * num_audios, 80, M)"""
|
||||
|
||||
|
||||
@@ -1031,7 +1031,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
|
||||
return audio_embeds
|
||||
|
||||
def _parse_and_validate_image_input(self,
|
||||
**kwargs: object) -> Optional[Dict]:
|
||||
**kwargs: object) -> Optional[dict]:
|
||||
input_image_embeds: NestedTensors = kwargs.get("input_image_embeds")
|
||||
if input_image_embeds is None:
|
||||
return None
|
||||
@@ -1238,7 +1238,7 @@ class Phi4MMForCausalLM(nn.Module, SupportsLoRA, SupportsMultiModal):
|
||||
sampling_metadata)
|
||||
return logits
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str,
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> None:
|
||||
weights = ((name, data) for name, data in weights
|
||||
if "lora" not in name)
|
||||
|
||||
Reference in New Issue
Block a user