Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-12 17:51:31 +01:00
committed by GitHub
parent 9bb38130cb
commit 8fcaaf6a16
944 changed files with 9490 additions and 10121 deletions

View File

@@ -8,7 +8,7 @@
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from collections.abc import Iterable, Mapping, Sequence
from typing import Annotated, Literal, Optional, Union
from typing import Annotated, Literal, TypeAlias
import torch
import torch.nn as nn
@@ -96,14 +96,14 @@ class SkyworkR1VImageEmbeddingInputs(TensorSchema):
type: Literal["image_embeds"] = "image_embeds"
data: Annotated[
Union[torch.Tensor, list[torch.Tensor]],
torch.Tensor | list[torch.Tensor],
TensorShape("ni", "ifs", "hs"),
]
SkyworkR1VImageInputs = Union[
SkyworkR1VImagePixelInputs, SkyworkR1VImageEmbeddingInputs
]
SkyworkR1VImageInputs: TypeAlias = (
SkyworkR1VImagePixelInputs | SkyworkR1VImageEmbeddingInputs
)
# adapted from https://huggingface.co/Skywork/Skywork-R1V-38B/
@@ -284,9 +284,9 @@ class SkyworkR1VProcessor:
config: PretrainedConfig,
tokenizer: AnyTokenizer,
*,
min_dynamic_patch: Optional[int] = None,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
min_dynamic_patch: int | None = None,
max_dynamic_patch: int | None = None,
dynamic_image_size: bool | None = None,
) -> None:
super().__init__()
@@ -324,7 +324,7 @@ class SkyworkR1VProcessor:
def get_image_repl(
self,
feature_size: int,
num_patches: Optional[int],
num_patches: int | None,
) -> PromptUpdateDetails[str]:
repl_features = IMG_CONTEXT * feature_size
repl_full = IMG_START + repl_features + IMG_END
@@ -334,10 +334,10 @@ class SkyworkR1VProcessor:
def resolve_min_max_num(
self,
*,
min_dynamic_patch: Optional[int] = None,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
use_thumbnail: Optional[bool] = None,
min_dynamic_patch: int | None = None,
max_dynamic_patch: int | None = None,
dynamic_image_size: bool | None = None,
use_thumbnail: bool | None = None,
) -> tuple[int, int]:
min_dynamic_patch = (
self.min_dynamic_patch if min_dynamic_patch is None else min_dynamic_patch
@@ -362,10 +362,10 @@ class SkyworkR1VProcessor:
def resolve_target_ratios(
self,
*,
min_dynamic_patch: Optional[int] = None,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
use_thumbnail: Optional[bool] = None,
min_dynamic_patch: int | None = None,
max_dynamic_patch: int | None = None,
dynamic_image_size: bool | None = None,
use_thumbnail: bool | None = None,
) -> list[tuple[int, int]]:
min_num, max_num = self.resolve_min_max_num(
min_dynamic_patch=min_dynamic_patch,
@@ -399,9 +399,9 @@ class SkyworkR1VProcessor:
def _images_to_pixel_values_lst(
self,
images: list[Image.Image],
min_dynamic_patch: Optional[int] = None,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
min_dynamic_patch: int | None = None,
max_dynamic_patch: int | None = None,
dynamic_image_size: bool | None = None,
) -> list[torch.Tensor]:
min_num, max_num = self.resolve_min_max_num(
min_dynamic_patch=min_dynamic_patch,
@@ -423,12 +423,12 @@ class SkyworkR1VProcessor:
def __call__(
self,
text: Optional[Union[str, list[str]]] = None,
images: Optional[Union[Image.Image, list[Image.Image]]] = None,
min_dynamic_patch: Optional[int] = None,
max_dynamic_patch: Optional[int] = None,
dynamic_image_size: Optional[bool] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
text: str | list[str] | None = None,
images: Image.Image | list[Image.Image] | None = None,
min_dynamic_patch: int | None = None,
max_dynamic_patch: int | None = None,
dynamic_image_size: bool | None = None,
return_tensors: str | TensorType | None = None,
) -> BatchFeature:
if text is None:
text = []
@@ -479,7 +479,7 @@ class SkyworkR1VProcessingInfo(BaseProcessingInfo):
**kwargs,
)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"image": None}
def get_num_image_tokens(
@@ -487,7 +487,7 @@ class SkyworkR1VProcessingInfo(BaseProcessingInfo):
*,
image_width: int,
image_height: int,
processor: Optional[SkyworkR1VProcessor],
processor: SkyworkR1VProcessor | None,
) -> int:
if processor is None:
processor = self.get_hf_processor()
@@ -532,7 +532,7 @@ class SkyworkR1VDummyInputsBuilder(BaseDummyInputsBuilder[SkyworkR1VProcessingIn
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
mm_options: Mapping[str, BaseDummyOptions] | None = None,
) -> MultiModalDataDict:
target_width, target_height = self.info.get_image_size_with_most_features()
num_images = mm_counts.get("image", 0)
@@ -650,7 +650,7 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
merge_by_field_config = True
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
if modality.startswith("image"):
return "<image>"
@@ -715,7 +715,7 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
def _init_vision_model(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig],
quant_config: QuantizationConfig | None,
*,
is_mono: bool,
prefix: str,
@@ -784,7 +784,7 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
def _parse_and_validate_image_input(
self, **kwargs: object
) -> Optional[SkyworkR1VImageInputs]:
) -> SkyworkR1VImageInputs | None:
pixel_values_flat = kwargs.pop("pixel_values_flat", None)
image_num_patches = kwargs.pop("image_num_patches", None)
image_embeds = kwargs.pop("image_embeds", None)
@@ -818,7 +818,7 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
def _process_image_input(
self,
image_input: SkyworkR1VImageInputs,
) -> Union[torch.Tensor, list[torch.Tensor], tuple[torch.Tensor, ...]]:
) -> torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor, ...]:
if image_input["type"] == "image_embeds":
return image_input["data"]
@@ -864,9 +864,9 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
if multimodal_embeddings is not None and len(multimodal_embeddings) > 0:
@@ -887,8 +887,8 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs: object,
) -> IntermediateTensors:
if intermediate_tensors is not None:
@@ -913,7 +913,7 @@ class SkyworkR1VChatModel(nn.Module, SupportsMultiModal, SupportsPP):
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> Optional[torch.Tensor]:
) -> torch.Tensor | None:
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: