Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-12 17:51:31 +01:00
committed by GitHub
parent 9bb38130cb
commit 8fcaaf6a16
944 changed files with 9490 additions and 10121 deletions

View File

@@ -22,10 +22,10 @@
# limitations under the License.
"""Inference-only Qwen2.5-Omni model (thinker part)."""
from collections.abc import Iterable, Mapping, Sequence
from collections.abc import Callable, Iterable, Mapping, Sequence
from copy import copy
from functools import partial
from typing import Annotated, Any, Callable, Literal, Optional, Union
from typing import Annotated, Any, Literal
import torch
import torch.nn as nn
@@ -125,7 +125,7 @@ class Qwen2_5OmniAudioFeatureInputs(TensorSchema):
type: Literal["audio_features"]
input_features: Annotated[
Union[torch.Tensor, list[torch.Tensor]],
torch.Tensor | list[torch.Tensor],
TensorShape("nmb", "tsl"),
]
@@ -191,7 +191,7 @@ class Qwen2_5OmniThinkerMultiModalDataParser(Qwen2VLMultiModalDataParser):
def _parse_audio_data(
self,
data: Union[dict[str, torch.Tensor], ModalityData[ImageItem]],
data: dict[str, torch.Tensor] | ModalityData[ImageItem],
) -> ModalityDataItems[Any, Any]:
if isinstance(data, dict):
return DictEmbeddingItems(
@@ -225,7 +225,7 @@ class Qwen2_5OmniThinkerProcessingInfo(
assert isinstance(feature_extractor, WhisperFeatureExtractor)
return feature_extractor
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"audio": None, "image": None, "video": None}
@@ -253,7 +253,7 @@ class Qwen2_5OmniThinkerDummyInputsBuilder(
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
mm_options: Mapping[str, BaseDummyOptions] | None = None,
) -> MultiModalDataDict:
num_audios = mm_counts.get("audio", 0)
num_images = mm_counts.get("image", 0)
@@ -420,7 +420,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
cls,
thinker_config: PretrainedConfig,
audio_len: int,
video_grid_thw: Union[list[int], torch.Tensor],
video_grid_thw: list[int] | torch.Tensor,
video_second_per_grid_t: float,
) -> list[int]:
"""Get video prompt updates when `use_audio_in_video` is True.
@@ -580,7 +580,7 @@ class Qwen2_5OmniThinkerMultiModalProcessor(
def _apply_hf_processor_main(
self,
prompt: Union[str, list[int]],
prompt: str | list[int],
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
@@ -665,7 +665,7 @@ class Qwen2_5OmniConditionalGenerationMixin:
def _parse_and_validate_audio_input(
self, **kwargs: object
) -> Optional[Qwen2_5OmniAudioFeatureInputs]:
) -> Qwen2_5OmniAudioFeatureInputs | None:
input_audio_features = kwargs.pop("input_audio_features", None)
audio_feature_lengths = kwargs.pop("audio_feature_lengths", None)
feature_attention_mask = kwargs.pop("feature_attention_mask", None)
@@ -693,7 +693,7 @@ class Qwen2_5OmniConditionalGenerationMixin:
def _parse_and_validate_image_input(
self,
**kwargs: dict[str, Any],
) -> Optional[Qwen2_5_VLImageInputs]:
) -> Qwen2_5_VLImageInputs | None:
pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)
image_grid_thw = kwargs.pop("image_grid_thw", None)
@@ -743,7 +743,7 @@ class Qwen2_5OmniConditionalGenerationMixin:
def _parse_and_validate_video_input(
self,
**kwargs: dict[str, Any],
) -> Optional[Qwen2_5_VLVideoInputs]:
) -> Qwen2_5_VLVideoInputs | None:
pixel_values_videos = kwargs.pop("pixel_values_videos", None)
video_embeds = kwargs.pop("video_embeds", None)
video_grid_thw = kwargs.pop("video_grid_thw", None)
@@ -892,7 +892,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
}
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
if modality.startswith("image"):
return "<|vision_start|><|IMAGE|><|vision_end|>"
if modality.startswith("video"):
@@ -991,12 +991,12 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
cls,
input_tokens: list[int],
hf_config: PretrainedConfig,
image_grid_thw: Union[list[list[int]], torch.Tensor],
video_grid_thw: Union[list[list[int]], torch.Tensor],
second_per_grid_ts: Optional[list[float]] = None,
image_grid_thw: list[list[int]] | torch.Tensor,
video_grid_thw: list[list[int]] | torch.Tensor,
second_per_grid_ts: list[float] | None = None,
context_len: int = 0,
seq_len: Optional[int] = None,
audio_feature_lengths: Optional[torch.Tensor] = None,
seq_len: int | None = None,
audio_feature_lengths: torch.Tensor | None = None,
use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:
"""Get mrope input positions and delta value (Qwen2.5-Omni version).
@@ -1225,9 +1225,9 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
def get_input_embeddings(
self,
input_ids: torch.Tensor,
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: Optional[torch.Tensor] = None,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
# This is to satisfy the type checker for each overload
@@ -1241,7 +1241,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
handle_oov_mm_token=handle_oov_mm_token,
)
def get_multimodal_embeddings_v0(self, **kwargs: object) -> Optional[NestedTensors]:
def get_multimodal_embeddings_v0(self, **kwargs: object) -> NestedTensors | None:
audio_input = self._parse_and_validate_audio_input(**kwargs)
image_input = self._parse_and_validate_image_input(**kwargs)
video_input = self._parse_and_validate_video_input(**kwargs)
@@ -1266,10 +1266,10 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
) -> torch.Tensor | IntermediateTensors:
if intermediate_tensors is not None:
inputs_embeds = None
@@ -1281,7 +1281,7 @@ class Qwen2_5OmniThinkerForConditionalGeneration(
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> Optional[torch.Tensor]:
) -> torch.Tensor | None:
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: