Update `Optional[x]` to `x | None` and `Union[x, y]` to `x | y` (#26633)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-12 17:51:31 +01:00
committed by GitHub
parent 9bb38130cb
commit 8fcaaf6a16
944 changed files with 9490 additions and 10121 deletions

View File

@@ -5,7 +5,7 @@ import math
from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass, fields
from functools import cached_property
from typing import Annotated, Literal, Optional, Union
from typing import Annotated, Literal
import torch
import torch.nn as nn
@@ -100,7 +100,7 @@ class PixtralImagePixelInputs(TensorSchema):
type: Literal["pixel_values"] = "pixel_values"
images: Annotated[
Union[torch.Tensor, list[torch.Tensor]],
torch.Tensor | list[torch.Tensor],
TensorShape("bn", 3, "h", "w", dynamic_dims={"h", "w"}),
]
@@ -144,9 +144,9 @@ class PixtralProcessorAdapter:
def __call__(
self,
text: Optional[Union[TextInput, list[TextInput]]] = None,
images: Optional[Union[ImageInput, list[ImageInput]]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
text: TextInput | list[TextInput] | None = None,
images: ImageInput | list[ImageInput] | None = None,
return_tensors: str | TensorType | None = None,
**kwargs,
) -> Mapping[str, NestedTensors]:
if text is None:
@@ -203,12 +203,12 @@ class PixtralProcessingInfo(BaseProcessingInfo):
def get_hf_processor(self) -> PixtralProcessorAdapter:
return PixtralProcessorAdapter(self.get_tokenizer())
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"image": None}
def get_vision_config(
self,
processor: Optional[PixtralProcessorAdapter] = None,
processor: PixtralProcessorAdapter | None = None,
):
if processor is None:
processor = self.get_hf_processor()
@@ -223,7 +223,7 @@ class PixtralProcessingInfo(BaseProcessingInfo):
*,
image_width: int,
image_height: int,
processor: Optional[PixtralProcessorAdapter] = None,
processor: PixtralProcessorAdapter | None = None,
) -> int:
if processor is None:
processor = self.get_hf_processor()
@@ -249,7 +249,7 @@ class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
mm_options: Mapping[str, BaseDummyOptions] | None = None,
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
@@ -270,7 +270,7 @@ class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
mm_options: Mapping[str, BaseDummyOptions] | None = None,
) -> ProcessorInputs:
tokenizer = self.info.get_tokenizer()
@@ -342,11 +342,11 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo])
def _cached_apply_hf_processor(
self,
prompt: Union[str, list[int]],
prompt: str | list[int],
mm_data_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
tokenization_kwargs: Mapping[str, object],
mm_uuids: Optional[MultiModalUUIDDict] = None,
mm_uuids: MultiModalUUIDDict | None = None,
) -> tuple[list[int], MultiModalProcessingInfo, bool]:
prompt_ids, mm_info, _ = super()._cached_apply_hf_processor(
prompt=prompt,
@@ -369,7 +369,7 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
merge_by_field_config = True
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
if modality.startswith("image"):
return None
@@ -420,7 +420,7 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
def _parse_and_validate_image_input(
self, **kwargs: object
) -> Optional[PixtralImagePixelInputs]:
) -> PixtralImagePixelInputs | None:
images = kwargs.pop("images", None)
if images is None:
return None
@@ -472,10 +472,10 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs: object,
) -> Union[torch.Tensor, IntermediateTensors]:
) -> torch.Tensor | IntermediateTensors:
"""Run forward pass for pixtral."""
if intermediate_tensors is not None:
inputs_embeds = None
@@ -489,7 +489,7 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> Optional[torch.Tensor]:
) -> torch.Tensor | None:
return self.language_model.compute_logits(hidden_states)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
@@ -717,7 +717,7 @@ class Transformer(nn.Module):
self,
x: torch.Tensor,
mask: torch.Tensor,
freqs_cis: Optional[torch.Tensor],
freqs_cis: torch.Tensor | None,
) -> torch.Tensor:
for layer in self.layers:
x = layer(x, mask=mask, freqs_cis=freqs_cis)
@@ -759,7 +759,7 @@ class VisionTransformer(nn.Module):
head_dim = self.args.hidden_size // self.args.num_attention_heads
assert head_dim % 2 == 0, "ROPE requires even head_dim"
self._freqs_cis: Optional[torch.Tensor] = None
self._freqs_cis: torch.Tensor | None = None
@property
def max_patches_per_side(self) -> int:
@@ -1015,7 +1015,7 @@ class PixtralHFMLP(nn.Module):
def __init__(
self,
config: PixtralVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
quant_config: QuantizationConfig | None = None,
*,
prefix: str = "",
) -> None:
@@ -1049,7 +1049,7 @@ class PixtralHFAttention(nn.Module):
def __init__(
self,
config: PixtralVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
quant_config: QuantizationConfig | None = None,
*,
prefix: str = "",
) -> None:
@@ -1084,7 +1084,7 @@ class PixtralHFAttention(nn.Module):
hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
position_embeddings: torch.Tensor,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, torch.Tensor | None]:
batch, patches, _ = hidden_states.size()
qkv_states, _ = self.qkv_proj(hidden_states)
@@ -1119,7 +1119,7 @@ class PixtralHFTransformerBlock(nn.Module):
def __init__(
self,
config: PixtralVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
quant_config: QuantizationConfig | None = None,
*,
prefix: str = "",
) -> None:
@@ -1155,9 +1155,9 @@ class PixtralHFTransformer(nn.Module):
def __init__(
self,
config: PixtralVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
quant_config: QuantizationConfig | None = None,
*,
num_hidden_layers_override: Optional[int] = None,
num_hidden_layers_override: int | None = None,
prefix: str = "",
) -> None:
super().__init__()
@@ -1202,10 +1202,10 @@ class PixtralHFVisionModel(nn.Module):
def __init__(
self,
config: PixtralVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
quant_config: QuantizationConfig | None = None,
*,
num_hidden_layers_override: Optional[int] = None,
require_post_norm: Optional[bool] = None,
num_hidden_layers_override: int | None = None,
require_post_norm: bool | None = None,
prefix: str = "",
) -> None:
super().__init__()
@@ -1247,8 +1247,8 @@ class PixtralHFVisionModel(nn.Module):
self,
pixel_values: list[torch.Tensor],
*,
select_layers: Optional[list[int]] = None,
feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None,
select_layers: list[int] | None = None,
feature_select_strategy: VisionFeatureSelectStrategy | None = None,
) -> tuple[torch.Tensor, ...]:
"""
Args: