Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -8,7 +8,7 @@
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# --------------------------------------------------------
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from typing import Annotated, Literal, Optional, Union
|
||||
from typing import Annotated, Literal, TypeAlias
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -96,14 +96,14 @@ class SkyworkR1VImageEmbeddingInputs(TensorSchema):
|
||||
type: Literal["image_embeds"] = "image_embeds"
|
||||
|
||||
data: Annotated[
|
||||
Union[torch.Tensor, list[torch.Tensor]],
|
||||
torch.Tensor | list[torch.Tensor],
|
||||
TensorShape("ni", "ifs", "hs"),
|
||||
]
|
||||
|
||||
|
||||
SkyworkR1VImageInputs = Union[
|
||||
SkyworkR1VImagePixelInputs, SkyworkR1VImageEmbeddingInputs
|
||||
]
|
||||
SkyworkR1VImageInputs: TypeAlias = (
|
||||
SkyworkR1VImagePixelInputs | SkyworkR1VImageEmbeddingInputs
|
||||
)
|
||||
|
||||
|
||||
# adapted from https://huggingface.co/Skywork/Skywork-R1V-38B/
|
||||
@@ -284,9 +284,9 @@ class SkyworkR1VProcessor:
|
||||
config: PretrainedConfig,
|
||||
tokenizer: AnyTokenizer,
|
||||
*,
|
||||
min_dynamic_patch: Optional[int] = None,
|
||||
max_dynamic_patch: Optional[int] = None,
|
||||
dynamic_image_size: Optional[bool] = None,
|
||||
min_dynamic_patch: int | None = None,
|
||||
max_dynamic_patch: int | None = None,
|
||||
dynamic_image_size: bool | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -324,7 +324,7 @@ class SkyworkR1VProcessor:
|
||||
def get_image_repl(
|
||||
self,
|
||||
feature_size: int,
|
||||
num_patches: Optional[int],
|
||||
num_patches: int | None,
|
||||
) -> PromptUpdateDetails[str]:
|
||||
repl_features = IMG_CONTEXT * feature_size
|
||||
repl_full = IMG_START + repl_features + IMG_END
|
||||
@@ -334,10 +334,10 @@ class SkyworkR1VProcessor:
|
||||
def resolve_min_max_num(
|
||||
self,
|
||||
*,
|
||||
min_dynamic_patch: Optional[int] = None,
|
||||
max_dynamic_patch: Optional[int] = None,
|
||||
dynamic_image_size: Optional[bool] = None,
|
||||
use_thumbnail: Optional[bool] = None,
|
||||
min_dynamic_patch: int | None = None,
|
||||
max_dynamic_patch: int | None = None,
|
||||
dynamic_image_size: bool | None = None,
|
||||
use_thumbnail: bool | None = None,
|
||||
) -> tuple[int, int]:
|
||||
min_dynamic_patch = (
|
||||
self.min_dynamic_patch if min_dynamic_patch is None else min_dynamic_patch
|
||||
@@ -362,10 +362,10 @@ class SkyworkR1VProcessor:
|
||||
def resolve_target_ratios(
|
||||
self,
|
||||
*,
|
||||
min_dynamic_patch: Optional[int] = None,
|
||||
max_dynamic_patch: Optional[int] = None,
|
||||
dynamic_image_size: Optional[bool] = None,
|
||||
use_thumbnail: Optional[bool] = None,
|
||||
min_dynamic_patch: int | None = None,
|
||||
max_dynamic_patch: int | None = None,
|
||||
dynamic_image_size: bool | None = None,
|
||||
use_thumbnail: bool | None = None,
|
||||
) -> list[tuple[int, int]]:
|
||||
min_num, max_num = self.resolve_min_max_num(
|
||||
min_dynamic_patch=min_dynamic_patch,
|
||||
@@ -399,9 +399,9 @@ class SkyworkR1VProcessor:
|
||||
def _images_to_pixel_values_lst(
|
||||
self,
|
||||
images: list[Image.Image],
|
||||
min_dynamic_patch: Optional[int] = None,
|
||||
max_dynamic_patch: Optional[int] = None,
|
||||
dynamic_image_size: Optional[bool] = None,
|
||||
min_dynamic_patch: int | None = None,
|
||||
max_dynamic_patch: int | None = None,
|
||||
dynamic_image_size: bool | None = None,
|
||||
) -> list[torch.Tensor]:
|
||||
min_num, max_num = self.resolve_min_max_num(
|
||||
min_dynamic_patch=min_dynamic_patch,
|
||||
@@ -423,12 +423,12 @@ class SkyworkR1VProcessor:
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text: Optional[Union[str, list[str]]] = None,
|
||||
images: Optional[Union[Image.Image, list[Image.Image]]] = None,
|
||||
min_dynamic_patch: Optional[int] = None,
|
||||
max_dynamic_patch: Optional[int] = None,
|
||||
dynamic_image_size: Optional[bool] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
text: str | list[str] | None = None,
|
||||
images: Image.Image | list[Image.Image] | None = None,
|
||||
min_dynamic_patch: int | None = None,
|
||||
max_dynamic_patch: int | None = None,
|
||||
dynamic_image_size: bool | None = None,
|
||||
return_tensors: str | TensorType | None = None,
|
||||
) -> BatchFeature:
|
||||
if text is None:
|
||||
text = []
|
||||
@@ -479,7 +479,7 @@ class SkyworkR1VProcessingInfo(BaseProcessingInfo):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
||||
return {"image": None}
|
||||
|
||||
def get_num_image_tokens(
|
||||
@@ -487,7 +487,7 @@ class SkyworkR1VProcessingInfo(BaseProcessingInfo):
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
processor: Optional[SkyworkR1VProcessor],
|
||||
processor: SkyworkR1VProcessor | None,
|
||||
) -> int:
|
||||
if processor is None:
|
||||
processor = self.get_hf_processor()
|
||||
@@ -532,7 +532,7 @@ class SkyworkR1VDummyInputsBuilder(BaseDummyInputsBuilder[SkyworkR1VProcessingIn
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
|
||||
mm_options: Mapping[str, BaseDummyOptions] | None = None,
|
||||
) -> MultiModalDataDict:
|
||||
target_width, target_height = self.info.get_image_size_with_most_features()
|
||||
num_images = mm_counts.get("image", 0)
|
||||
@@ -650,7 +650,7 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
merge_by_field_config = True
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
|
||||
if modality.startswith("image"):
|
||||
return "<image>"
|
||||
|
||||
@@ -715,7 +715,7 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
def _init_vision_model(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig],
|
||||
quant_config: QuantizationConfig | None,
|
||||
*,
|
||||
is_mono: bool,
|
||||
prefix: str,
|
||||
@@ -784,7 +784,7 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object
|
||||
) -> Optional[SkyworkR1VImageInputs]:
|
||||
) -> SkyworkR1VImageInputs | None:
|
||||
pixel_values_flat = kwargs.pop("pixel_values_flat", None)
|
||||
image_num_patches = kwargs.pop("image_num_patches", None)
|
||||
image_embeds = kwargs.pop("image_embeds", None)
|
||||
@@ -818,7 +818,7 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
def _process_image_input(
|
||||
self,
|
||||
image_input: SkyworkR1VImageInputs,
|
||||
) -> Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor, ...]]:
|
||||
) -> torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor, ...]:
|
||||
if image_input["type"] == "image_embeds":
|
||||
return image_input["data"]
|
||||
|
||||
@@ -864,9 +864,9 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
multimodal_embeddings: MultiModalEmbeddings | None = None,
|
||||
*,
|
||||
is_multimodal: Optional[torch.Tensor] = None,
|
||||
is_multimodal: torch.Tensor | None = None,
|
||||
handle_oov_mm_token: bool = False,
|
||||
) -> torch.Tensor:
|
||||
if multimodal_embeddings is not None and len(multimodal_embeddings) > 0:
|
||||
@@ -887,8 +887,8 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
intermediate_tensors: IntermediateTensors | None = None,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
**kwargs: object,
|
||||
) -> IntermediateTensors:
|
||||
if intermediate_tensors is not None:
|
||||
@@ -913,7 +913,7 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> Optional[torch.Tensor]:
|
||||
) -> torch.Tensor | None:
|
||||
return self.language_model.compute_logits(hidden_states)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
|
||||
|
||||
Reference in New Issue
Block a user