Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-10-12 17:51:31 +01:00
Committed by: GitHub
Parent: 9bb38130cb
Commit: 8fcaaf6a16

944 changed files with 9490 additions and 10121 deletions

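For readers unfamiliar with the pattern, here is a minimal before/after sketch of the rewrite this commit applies mechanically across the codebase (the function below is illustrative, not taken from the diff):

    from typing import Optional, Union  # pre-PEP 604 spellings

    # Before: Optional/Union imported from typing
    def describe_old(count: Optional[int] = None) -> Union[int, str]:
        return count if count is not None else "unknown"

    # After: PEP 604 unions (Python 3.10+); no typing imports needed
    def describe_new(count: int | None = None) -> int | str:
        return count if count is not None else "unknown"

The two spellings compare equal at runtime, so the change is behavior-preserving; only the annotation style is modernized.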

@@ -4,7 +4,7 @@ import math
 from collections.abc import Iterable, Mapping, Sequence
 from itertools import product
 from math import ceil, sqrt
-from typing import Annotated, Any, Literal, Optional, Union
+from typing import Annotated, Any, Literal, TypeAlias
 
 import numpy as np
 import torch
@@ -71,7 +71,7 @@ class Step3VLImagePixelInputs(TensorSchema):
     type: Literal["pixel_values"]
     pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
     patch_pixel_values: Annotated[
-        Optional[torch.Tensor], TensorShape("bnp", 3, "hp", "wp")
+        torch.Tensor | None, TensorShape("bnp", 3, "hp", "wp")
     ]
     num_patches: Annotated[torch.Tensor, TensorShape("bn")]
@@ -88,7 +88,7 @@ class Step3VLImageEmbeddingInputs(TensorSchema):
     data: Annotated[torch.Tensor, TensorShape("bn", "f", "h")]
 
-Step3VLImageInputs = Union[Step3VLImagePixelInputs, Step3VLImageEmbeddingInputs]
+Step3VLImageInputs: TypeAlias = Step3VLImagePixelInputs | Step3VLImageEmbeddingInputs
 
 ImageWithPatches = tuple[Image.Image, list[Image.Image], list[int] | None]
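Note that the alias in this hunk also gains an explicit TypeAlias annotation (PEP 613): with Union gone, the marker makes it unambiguous to type checkers that the assignment is a type alias rather than an ordinary module-level variable. A standalone sketch, with PixelInputs and EmbeddingInputs as hypothetical stand-ins for the schema classes above:

    from typing import TypeAlias

    class PixelInputs: ...
    class EmbeddingInputs: ...

    # Explicitly marked alias; the A | B syntax between classes is
    # evaluated eagerly and therefore needs Python 3.10+.
    ImageInputs: TypeAlias = PixelInputs | EmbeddingInputs

    def validate(inputs: ImageInputs) -> bool:
        return isinstance(inputs, (PixelInputs, EmbeddingInputs))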
@@ -409,7 +409,7 @@ class Step3VLProcessor:
         self,
         num_images: int,
         num_patches: int,
-        patch_new_line_idx: Optional[list[bool]],
+        patch_new_line_idx: list[bool] | None,
     ) -> tuple[str, list[int]]:
         if num_patches > 0:
             patch_repl, patch_repl_ids = self._get_patch_repl(
@@ -438,9 +438,9 @@ class Step3VLProcessor:
     def __call__(
         self,
-        text: Optional[Union[str, list[str]]] = None,
-        images: Optional[Union[Image.Image, list[Image.Image]]] = None,
-        return_tensors: Optional[Union[str, TensorType]] = None,
+        text: str | list[str] | None = None,
+        images: Image.Image | list[Image.Image] | None = None,
+        return_tensors: str | TensorType | None = None,
     ) -> BatchFeature:
         if text is None:
             text = []
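The one non-mechanical rewrite in the hunk above is the nested form: Optional[Union[str, list[str]]] flattens into a single chained union rather than nesting. A quick interpreter check (illustrative) that the two spellings denote the same type:

    from typing import Optional, Union

    # Unions flatten, so the nested and the flat spelling compare equal:
    assert Optional[Union[str, list[str]]] == (str | list[str] | None)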
@@ -513,7 +513,7 @@ class Step3VLProcessingInfo(BaseProcessingInfo):
             self.get_tokenizer(),
         )
 
-    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
         return {"image": None}
 
     def get_max_image_tokens(self) -> int:
@@ -556,7 +556,7 @@ class Step3VLDummyInputsBuilder(BaseDummyInputsBuilder[Step3VLProcessingInfo]):
         self,
         seq_len: int,
         mm_counts: Mapping[str, int],
-        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
+        mm_options: Mapping[str, BaseDummyOptions] | None = None,
     ) -> MultiModalDataDict:
         target_width, target_height = self.info.get_image_size_with_most_features()
         num_images = mm_counts.get("image", 0)
@@ -716,7 +716,7 @@ class Step3VisionAttention(nn.Module):
     def __init__(
         self,
         config,
-        quant_config: Optional[QuantizationConfig] = None,
+        quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
     ):
@@ -778,7 +778,7 @@ class Step3VisionMLP(nn.Module):
     def __init__(
         self,
         config,
-        quant_config: Optional[QuantizationConfig] = None,
+        quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
     ):
@@ -813,7 +813,7 @@ class Step3VisionEncoderLayer(nn.Module):
     def __init__(
         self,
         config: Step3VisionEncoderConfig,
-        quant_config: Optional[QuantizationConfig] = None,
+        quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
     ):
@@ -848,7 +848,7 @@ class Step3VisionEncoder(nn.Module):
     def __init__(
         self,
         config: Step3VisionEncoderConfig,
-        quant_config: Optional[QuantizationConfig] = None,
+        quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
     ):
@@ -881,7 +881,7 @@ class Step3VisionTransformer(nn.Module):
     def __init__(
         self,
         config: Step3VisionEncoderConfig,
-        quant_config: Optional[QuantizationConfig] = None,
+        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
        use_data_parallel: bool = False,
    ):
@@ -927,7 +927,7 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
     supports_encoder_tp_data = True
 
     @classmethod
-    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
+    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
         if modality.startswith("image"):
             return "<im_patch>"
@@ -994,7 +994,7 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
     def _parse_and_validate_image_input(
         self, **kwargs: object
-    ) -> Optional[Step3VLImageInputs]:
+    ) -> Step3VLImageInputs | None:
         pixel_values = kwargs.pop("pixel_values", None)
         patch_pixel_values = kwargs.pop("patch_pixel_values", None)
         num_patches = kwargs.pop("num_patches", None)
@@ -1085,9 +1085,9 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
     def get_input_embeddings(
         self,
         input_ids: torch.Tensor,
-        multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
+        multimodal_embeddings: MultiModalEmbeddings | None = None,
         *,
-        is_multimodal: Optional[torch.Tensor] = None,
+        is_multimodal: torch.Tensor | None = None,
         # Multi-modal token ID may exceed vocab size
         handle_oov_mm_token: bool = True,
     ) -> torch.Tensor:
@@ -1106,10 +1106,10 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
+        intermediate_tensors: IntermediateTensors | None = None,
+        inputs_embeds: torch.Tensor | None = None,
         **kwargs: object,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
+    ) -> torch.Tensor | IntermediateTensors:
         if intermediate_tensors is not None:
             inputs_embeds = None
         elif inputs_embeds is None:
@@ -1130,7 +1130,7 @@ class Step3VLForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-    ) -> Optional[torch.Tensor]:
+    ) -> torch.Tensor | None:
         return self.language_model.compute_logits(hidden_states)
 
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):