Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author:    Harry Mellor
Date:      2025-10-12 17:51:31 +01:00
Committed: GitHub
Parent:    9bb38130cb
Commit:    8fcaaf6a16
944 changed files with 9490 additions and 10121 deletions
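
For readers skimming the diff below, here is a minimal sketch of the mechanical rewrite this commit applies across all 944 files (the `scale` function is hypothetical, not taken from the diff): `Optional[X]` becomes `X | None` and `Union[X, Y]` becomes `X | Y`, per PEP 604.

```python
from collections.abc import Callable  # moved here from typing (see first hunk)

# Before: from typing import Callable, Optional, Union
#   def scale(x: Optional[int], fn: Callable[[int], int]) -> Union[int, str]: ...

# After: PEP 604 union syntax, no typing imports needed
def scale(x: int | None, fn: Callable[[int], int]) -> int | str:
    return fn(x) if x is not None else "no input"
```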


@@ -24,10 +24,10 @@
# limitations under the License.
"""Inference-only Qwen3VL model compatible with HuggingFace weights."""
-from collections.abc import Iterable, Mapping, Sequence
+from collections.abc import Callable, Iterable, Mapping, Sequence
from functools import partial
from itertools import islice
-from typing import Any, Callable, Optional, Union
+from typing import Any
import numpy as np
import torch
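
Note that the import hunk above also moves `Callable` from `typing` to `collections.abc`: the `typing` alias has been documented as deprecated since Python 3.9 (PEP 585), and the `collections.abc` version subscripts identically. A small, hypothetical illustration:

```python
from collections.abc import Callable  # typing.Callable is deprecated (PEP 585)

def apply(fn: Callable[[int], int], x: int) -> int:
    # the Callable[[args], return] subscription syntax is unchanged
    return fn(x)

assert apply(lambda v: v + 1, 41) == 42
```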
@@ -151,7 +151,7 @@ class Qwen3_VisionMLP(nn.Module):
hidden_features: int,
bias: bool = False,
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
-quant_config: Optional[QuantizationConfig] = None,
+quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
):
@@ -188,8 +188,8 @@ class Qwen3_VisionBlock(nn.Module):
num_heads: int,
mlp_hidden_dim: int,
act_fn: Callable[[torch.Tensor], torch.Tensor] = F.silu,
-norm_layer: Optional[Callable[[int], nn.Module]] = None,
-quant_config: Optional[QuantizationConfig] = None,
+norm_layer: Callable[[int], nn.Module] | None = None,
+quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
attn_backend: _Backend = _Backend.TORCH_SDPA,
@@ -225,8 +225,8 @@ class Qwen3_VisionBlock(nn.Module):
x: torch.Tensor,
cu_seqlens: torch.Tensor,
rotary_pos_emb: torch.Tensor,
-max_seqlen: Optional[int] = None, # Only used for Flash Attention
-seqlens: Optional[list[int]] = None, # Only used for xFormers
+max_seqlen: int | None = None, # Only used for Flash Attention
+seqlens: list[int] | None = None, # Only used for xFormers
) -> torch.Tensor:
x = x + self.attn(
self.norm1(x),
@@ -245,10 +245,10 @@ class Qwen3_VisionPatchMerger(nn.Module):
self,
d_model: int,
context_dim: int,
-norm_layer: Optional[Callable[[int], nn.Module]] = None,
+norm_layer: Callable[[int], nn.Module] | None = None,
spatial_merge_size: int = 2,
use_postshuffle_norm: bool = False,
-quant_config: Optional[QuantizationConfig] = None,
+quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
@@ -297,7 +297,7 @@ class Qwen3_VisionTransformer(nn.Module):
self,
vision_config: Qwen3VLVisionConfig,
norm_eps: float = 1e-6,
-quant_config: Optional[QuantizationConfig] = None,
+quant_config: QuantizationConfig | None = None,
prefix: str = "",
use_data_parallel: bool = False,
) -> None:
@@ -511,7 +511,7 @@ class Qwen3_VisionTransformer(nn.Module):
def compute_attn_mask_seqlen(
self,
cu_seqlens: torch.Tensor,
-) -> tuple[Optional[int], Optional[list[int]]]:
+) -> tuple[int | None, list[int] | None]:
max_seqlen, seqlens = None, None
if (
self.attn_backend == _Backend.FLASH_ATTN
@@ -625,9 +625,7 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
image_height: int,
num_frames: int = 2,
do_resize: bool = True,
-image_processor: Optional[
-    Union[Qwen2VLImageProcessorFast, Qwen3VLVideoProcessor]
-],
+image_processor: Qwen2VLImageProcessorFast | Qwen3VLVideoProcessor | None,
) -> tuple[ImageSize, int]:
if image_processor is None and num_frames > 1:
image_processor = self.get_video_processor()
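
The hunk above covers the compound case: because `Optional[X]` is shorthand for `Union[X, None]`, the nested `Optional[Union[A, B]]` flattens to `A | B | None`. A minimal sketch showing the two spellings are equivalent (`OldStyle` and `NewStyle` are illustrative names):

```python
from typing import Optional, Union, get_args

OldStyle = Optional[Union[int, str]]
NewStyle = int | str | None

# both flatten to the same three-member union
assert get_args(OldStyle) == get_args(NewStyle) == (int, str, type(None))
```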
@@ -726,8 +724,8 @@ class Qwen3VLProcessingInfo(Qwen2VLProcessingInfo):
self,
metadata: dict[str, Any],
out_item: MultiModalKwargsItem,
-do_sample_frames: Optional[bool] = None,
-sampled_fps: Optional[float] = None,
+do_sample_frames: bool | None = None,
+sampled_fps: float | None = None,
) -> list[int]:
video_processor = self.get_video_processor()
merge_size = video_processor.merge_size
@@ -778,7 +776,7 @@ class Qwen3VLDummyInputsBuilder(BaseDummyInputsBuilder[Qwen3VLProcessingInfo]):
self,
seq_len: int,
mm_counts: Mapping[str, int],
-mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
+mm_options: Mapping[str, BaseDummyOptions] | None = None,
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
num_videos = mm_counts.get("video", 0)
@@ -1096,11 +1094,11 @@ class Qwen3LLMModel(Qwen3Model):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
-intermediate_tensors: Optional[IntermediateTensors] = None,
-inputs_embeds: Optional[torch.Tensor] = None,
+intermediate_tensors: IntermediateTensors | None = None,
+inputs_embeds: torch.Tensor | None = None,
# args for deepstack
-deepstack_input_embeds: Optional[IntermediateTensors] = None,
-) -> Union[torch.Tensor, IntermediateTensors]:
+deepstack_input_embeds: IntermediateTensors | None = None,
+) -> torch.Tensor | IntermediateTensors:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
@@ -1201,7 +1199,7 @@ class Qwen3VLForConditionalGeneration(
)
@classmethod
-def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
+def get_placeholder_str(cls, modality: str, i: int) -> str | None:
if modality.startswith("image"):
return "<|vision_start|><|image_pad|><|vision_end|>"
if modality.startswith("video"):
@@ -1314,7 +1312,7 @@ class Qwen3VLForConditionalGeneration(
def _parse_and_validate_image_input(
self, **kwargs: object
-) -> Optional[Qwen2_5_VLImageInputs]:
+) -> Qwen2_5_VLImageInputs | None:
pixel_values = kwargs.pop("pixel_values", None)
image_embeds = kwargs.pop("image_embeds", None)
image_grid_thw = kwargs.pop("image_grid_thw", None)
@@ -1363,7 +1361,7 @@ class Qwen3VLForConditionalGeneration(
def _parse_and_validate_video_input(
self, **kwargs: object
-) -> Optional[Qwen2_5_VLVideoInputs]:
+) -> Qwen2_5_VLVideoInputs | None:
pixel_values_videos = kwargs.pop("pixel_values_videos", None)
video_embeds = kwargs.pop("video_embeds", None)
video_grid_thw = kwargs.pop("video_grid_thw", None)
@@ -1486,12 +1484,12 @@ class Qwen3VLForConditionalGeneration(
cls,
input_tokens: list[int],
hf_config: PretrainedConfig,
-image_grid_thw: Union[list[list[int]], torch.Tensor],
-video_grid_thw: Union[list[list[int]], torch.Tensor],
+image_grid_thw: list[list[int]] | torch.Tensor,
+video_grid_thw: list[list[int]] | torch.Tensor,
context_len: int = 0,
-seq_len: Optional[int] = None,
-second_per_grid_ts: Optional[list[float]] = None,
-audio_feature_lengths: Optional[torch.Tensor] = None,
+seq_len: int | None = None,
+second_per_grid_ts: list[float] | None = None,
+audio_feature_lengths: torch.Tensor | None = None,
use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:
"""Get mrope input positions and delta value."""
@@ -1596,7 +1594,7 @@ class Qwen3VLForConditionalGeneration(
def get_multimodal_embeddings(
self, **kwargs: object
-) -> Optional[MultiModalEmbeddings]:
+) -> MultiModalEmbeddings | None:
mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
if not mm_input_by_modality:
return None
@@ -1661,9 +1659,9 @@ class Qwen3VLForConditionalGeneration(
def get_input_embeddings(
self,
input_ids: torch.Tensor,
-multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
+multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
-is_multimodal: Optional[torch.Tensor] = None,
+is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = self._get_text_embeddings(
@@ -1710,10 +1708,10 @@ class Qwen3VLForConditionalGeneration(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
-intermediate_tensors: Optional[IntermediateTensors] = None,
-inputs_embeds: Optional[torch.Tensor] = None,
+intermediate_tensors: IntermediateTensors | None = None,
+inputs_embeds: torch.Tensor | None = None,
**kwargs: object,
-) -> Union[torch.Tensor, IntermediateTensors]:
+) -> torch.Tensor | IntermediateTensors:
"""Run forward pass for Qwen3VL.
Args:
@@ -1769,7 +1767,7 @@ class Qwen3VLForConditionalGeneration(
def compute_logits(
self,
hidden_states: torch.Tensor,
-) -> Optional[torch.Tensor]:
+) -> torch.Tensor | None:
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
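
One compatibility note: `X | None` inside an annotation is evaluated at import time, so the new spelling needs Python 3.10+ at runtime unless annotation evaluation is deferred. A hedged sketch of the standard fallback for older interpreters (the `head` function is hypothetical, not from this commit), should a downstream fork need to run such code on 3.9:

```python
from __future__ import annotations  # PEP 563: annotations stay unevaluated strings

def head(items: list[int] | None) -> int | None:
    # `list[int] | None` is never evaluated at runtime here, so the module
    # imports cleanly even on Python 3.9
    return items[0] if items else None
```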