Update Optional[x] -> x | None and Union[x, y] -> x | y (#26633)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor authored on 2025-10-12 17:51:31 +01:00, committed by GitHub
parent 9bb38130cb
commit 8fcaaf6a16
944 changed files with 9490 additions and 10121 deletions
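For readers skimming the diff, the pattern is uniform: Optional[X] becomes X | None, Union[X, Y] becomes X | Y (PEP 604 syntax, available since Python 3.10), Callable moves from typing to collections.abc, and module-level aliases gain an explicit TypeAlias annotation. A minimal before/after sketch of the style (the names scale, Number, and apply below are hypothetical illustrations, not code from this diff):

```python
# New-style spellings applied throughout this commit (PEP 604, Python 3.10+).
from collections.abc import Callable  # Callable now comes from collections.abc
from typing import TypeAlias

# Old spelling, as removed by the commit:
#   from typing import Callable, Optional, Union
#   def scale(x: int, factor: Optional[float] = None) -> Union[int, float]: ...

# New spelling:
def scale(x: int, factor: float | None = None) -> int | float:
    """Return x scaled by factor, or x unchanged when factor is None."""
    return x if factor is None else x * factor

# Explicit aliases keep the same union syntax (cf. _Tuple2 in the diff below):
Number: TypeAlias = int | float

def apply(fn: Callable[[int], Number], value: int) -> Number:
    return fn(value)
```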

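One runtime detail worth keeping in mind for any helper that introspects these annotations: an X | Y union is a types.UnionType object rather than a typing.Union, although typing.get_args sees both spellings identically. A small illustrative check (assumes plain CPython 3.10+; not code from vLLM):

```python
# Hedged sketch: how the two union spellings compare under runtime introspection.
import types
import typing

new_style = int | None            # types.UnionType instance (PEP 604)
old_style = typing.Optional[int]  # typing.Union[int, None]

# Both decompose to the same argument tuple...
assert typing.get_args(new_style) == typing.get_args(old_style) == (int, type(None))

# ...but their origins differ, so code that special-cases
# `get_origin(hint) is typing.Union` must also accept types.UnionType.
assert typing.get_origin(new_style) is types.UnionType
assert typing.get_origin(old_style) is typing.Union

# PEP 604 unions are also valid second arguments to isinstance() on 3.10+.
assert isinstance(None, new_style)
```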

@@ -25,8 +25,8 @@
import collections
import collections.abc
-from collections.abc import Iterable, Mapping, Sequence
-from typing import Any, Callable, Optional, TypedDict, Union, cast
+from collections.abc import Callable, Iterable, Mapping, Sequence
+from typing import Any, TypeAlias, TypedDict, cast
import numpy as np
import torch
@@ -66,7 +66,7 @@ from vllm.transformers_utils.configs.midashenglm import DashengConfig
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .utils import AutoWeightsLoader, init_vllm_registered_model, maybe_prefix
-_Tuple2 = Union[int, tuple[int, int], Sequence[int]]
+_Tuple2: TypeAlias = int | tuple[int, int] | Sequence[int]
def _resolve_tuple2(x: _Tuple2) -> tuple[int, int]:
@@ -105,7 +105,7 @@ class AudioPatchEmbed(nn.Module):
patch_stride: _Tuple2 = 16,
in_chans: int = 1,
embed_dim: int = 768,
-norm_layer: Optional[Callable] = None,
+norm_layer: Callable | None = None,
flatten: bool = False,
):
super().__init__()
@@ -151,9 +151,9 @@ class DashengMlp(nn.Module):
def __init__(
self,
in_features: int,
-hidden_features: Optional[int] = None,
-out_features: Optional[int] = None,
-quant_config: Optional[QuantizationConfig] = None,
+hidden_features: int | None = None,
+out_features: int | None = None,
+quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
@@ -186,7 +186,7 @@ class DashengAttention(nn.Module):
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
-quant_config: Optional[QuantizationConfig] = None,
+quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
@@ -226,7 +226,7 @@ class DashengAttention(nn.Module):
prefix=f"{prefix}.proj",
)
-def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
+def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None):
B, N, C = x.shape
qkv, _ = self.qkv(x)
@@ -253,8 +253,8 @@ class DashengBlock(nn.Module):
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = False,
-init_values: Optional[float] = None,
-quant_config: Optional[QuantizationConfig] = None,
+init_values: float | None = None,
+quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
@@ -285,7 +285,7 @@ class DashengBlock(nn.Module):
def forward(
self,
x: torch.Tensor,
-mask: Optional[torch.Tensor] = None,
+mask: torch.Tensor | None = None,
) -> torch.Tensor:
x = x + self.ls1(self.attn(self.norm1(x), mask))
x = x + self.ls2(self.mlp(self.norm2(x)))
@@ -349,7 +349,7 @@ class DashengAudioTransformer(nn.Module):
def __init__(
self,
config: DashengConfig,
-quant_config: Optional[QuantizationConfig] = None,
+quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
@@ -393,7 +393,7 @@ class DashengAudioTransformer(nn.Module):
def forward_features(
self,
x: torch.Tensor,
-mask: Optional[torch.Tensor] = None,
+mask: torch.Tensor | None = None,
) -> torch.Tensor:
t = x.shape[-1]
x = x + self.time_pos_embed[:, :, :, :t]
@@ -418,8 +418,8 @@ class DashengAudioTransformer(nn.Module):
def forward(
self,
x: torch.Tensor,
-x_length: Optional[torch.Tensor] = None,
-) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+x_length: torch.Tensor | None = None,
+) -> tuple[torch.Tensor, torch.Tensor | None]:
x = self.front_end(x)
x = x.to(self.time_pos_embed.dtype)
target_length_in_patches = self.target_length // 4
@@ -462,8 +462,8 @@ class AudioProjectorSubsample(nn.Module):
in_dim: int,
out_dim: int,
downsample_rate=5,
-dtype: Optional[torch.dtype] = None,
-quant_config: Optional[QuantizationConfig] = None,
+dtype: torch.dtype | None = None,
+quant_config: QuantizationConfig | None = None,
prefix: str = "",
):
super().__init__()
@@ -524,7 +524,7 @@ class MiDashengLMProcessingInfo(BaseProcessingInfo):
feature_extractor = hf_processor.feature_extractor
return feature_extractor
-def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"audio": None}
def get_min_audio_len(self):
@@ -550,7 +550,7 @@ class MiDashengLMDummyInputsBuilder(BaseDummyInputsBuilder[MiDashengLMProcessing
self,
seq_len: int,
mm_counts: Mapping[str, int],
-mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
+mm_options: Mapping[str, BaseDummyOptions] | None = None,
) -> MultiModalDataDict:
num_audios = mm_counts.get("audio", 0)
@@ -689,7 +689,7 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
}
@classmethod
-def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
+def get_placeholder_str(cls, modality: str, i: int) -> str | None:
if modality.startswith("audio"):
return "<|audio_bos|><|AUDIO|><|audio_eos|>"
@@ -750,7 +750,7 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
def _parse_and_validate_audio_input(
self, **kwargs: object
-) -> Optional[MiDashengLMAudioInputs]:
+) -> MiDashengLMAudioInputs | None:
input_values = kwargs.pop("input_values", None)
audio_length = kwargs.pop("audio_length", None)
@@ -820,10 +820,10 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
-intermediate_tensors: Optional[IntermediateTensors] = None,
-inputs_embeds: Optional[torch.Tensor] = None,
+intermediate_tensors: IntermediateTensors | None = None,
+inputs_embeds: torch.Tensor | None = None,
**kwargs: object,
-) -> Union[torch.Tensor, IntermediateTensors]:
+) -> torch.Tensor | IntermediateTensors:
if intermediate_tensors is not None:
inputs_embeds = None
elif inputs_embeds is None:
@@ -845,7 +845,7 @@ class MiDashengLMModel(nn.Module, SupportsMultiModal, SupportsPP):
def compute_logits(
self,
hidden_states: torch.Tensor,
-) -> Optional[torch.Tensor]:
+) -> torch.Tensor | None:
return self.decoder.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: