Update Optional[x] to x | None and Union[x, y] to x | y (#26633)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -5,7 +5,7 @@ import math
|
||||
from collections.abc import Iterable, Mapping, Sequence
|
||||
from dataclasses import dataclass, fields
|
||||
from functools import cached_property
|
||||
from typing import Annotated, Literal, Optional, Union
|
||||
from typing import Annotated, Literal
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -100,7 +100,7 @@ class PixtralImagePixelInputs(TensorSchema):
|
||||
type: Literal["pixel_values"] = "pixel_values"
|
||||
|
||||
images: Annotated[
|
||||
Union[torch.Tensor, list[torch.Tensor]],
|
||||
torch.Tensor | list[torch.Tensor],
|
||||
TensorShape("bn", 3, "h", "w", dynamic_dims={"h", "w"}),
|
||||
]
|
||||
|
||||
@@ -144,9 +144,9 @@ class PixtralProcessorAdapter:
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text: Optional[Union[TextInput, list[TextInput]]] = None,
|
||||
images: Optional[Union[ImageInput, list[ImageInput]]] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
text: TextInput | list[TextInput] | None = None,
|
||||
images: ImageInput | list[ImageInput] | None = None,
|
||||
return_tensors: str | TensorType | None = None,
|
||||
**kwargs,
|
||||
) -> Mapping[str, NestedTensors]:
|
||||
if text is None:
|
||||
@@ -203,12 +203,12 @@ class PixtralProcessingInfo(BaseProcessingInfo):
|
||||
def get_hf_processor(self) -> PixtralProcessorAdapter:
|
||||
return PixtralProcessorAdapter(self.get_tokenizer())
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
||||
return {"image": None}
|
||||
|
||||
def get_vision_config(
|
||||
self,
|
||||
processor: Optional[PixtralProcessorAdapter] = None,
|
||||
processor: PixtralProcessorAdapter | None = None,
|
||||
):
|
||||
if processor is None:
|
||||
processor = self.get_hf_processor()
|
||||
@@ -223,7 +223,7 @@ class PixtralProcessingInfo(BaseProcessingInfo):
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
processor: Optional[PixtralProcessorAdapter] = None,
|
||||
processor: PixtralProcessorAdapter | None = None,
|
||||
) -> int:
|
||||
if processor is None:
|
||||
processor = self.get_hf_processor()
|
||||
@@ -249,7 +249,7 @@ class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
|
||||
mm_options: Mapping[str, BaseDummyOptions] | None = None,
|
||||
) -> MultiModalDataDict:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
@@ -270,7 +270,7 @@ class PixtralDummyInputsBuilder(BaseDummyInputsBuilder[PixtralProcessingInfo]):
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
|
||||
mm_options: Mapping[str, BaseDummyOptions] | None = None,
|
||||
) -> ProcessorInputs:
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
|
||||
@@ -342,11 +342,11 @@ class PixtralMultiModalProcessor(BaseMultiModalProcessor[PixtralProcessingInfo])
|
||||
|
||||
def _cached_apply_hf_processor(
|
||||
self,
|
||||
prompt: Union[str, list[int]],
|
||||
prompt: str | list[int],
|
||||
mm_data_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
tokenization_kwargs: Mapping[str, object],
|
||||
mm_uuids: Optional[MultiModalUUIDDict] = None,
|
||||
mm_uuids: MultiModalUUIDDict | None = None,
|
||||
) -> tuple[list[int], MultiModalProcessingInfo, bool]:
|
||||
prompt_ids, mm_info, _ = super()._cached_apply_hf_processor(
|
||||
prompt=prompt,
|
||||
@@ -369,7 +369,7 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
|
||||
merge_by_field_config = True
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
|
||||
if modality.startswith("image"):
|
||||
return None
|
||||
|
||||
@@ -420,7 +420,7 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
|
||||
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object
|
||||
) -> Optional[PixtralImagePixelInputs]:
|
||||
) -> PixtralImagePixelInputs | None:
|
||||
images = kwargs.pop("images", None)
|
||||
if images is None:
|
||||
return None
|
||||
@@ -472,10 +472,10 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
intermediate_tensors: IntermediateTensors | None = None,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
**kwargs: object,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
) -> torch.Tensor | IntermediateTensors:
|
||||
"""Run forward pass for pixtral."""
|
||||
if intermediate_tensors is not None:
|
||||
inputs_embeds = None
|
||||
@@ -489,7 +489,7 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> Optional[torch.Tensor]:
|
||||
) -> torch.Tensor | None:
|
||||
return self.language_model.compute_logits(hidden_states)
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||
@@ -717,7 +717,7 @@ class Transformer(nn.Module):
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
mask: torch.Tensor,
|
||||
freqs_cis: Optional[torch.Tensor],
|
||||
freqs_cis: torch.Tensor | None,
|
||||
) -> torch.Tensor:
|
||||
for layer in self.layers:
|
||||
x = layer(x, mask=mask, freqs_cis=freqs_cis)
|
||||
@@ -759,7 +759,7 @@ class VisionTransformer(nn.Module):
|
||||
|
||||
head_dim = self.args.hidden_size // self.args.num_attention_heads
|
||||
assert head_dim % 2 == 0, "ROPE requires even head_dim"
|
||||
self._freqs_cis: Optional[torch.Tensor] = None
|
||||
self._freqs_cis: torch.Tensor | None = None
|
||||
|
||||
@property
|
||||
def max_patches_per_side(self) -> int:
|
||||
@@ -1015,7 +1015,7 @@ class PixtralHFMLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PixtralVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
*,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
@@ -1049,7 +1049,7 @@ class PixtralHFAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PixtralVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
*,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
@@ -1084,7 +1084,7 @@ class PixtralHFAttention(nn.Module):
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
position_embeddings: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor | None]:
|
||||
batch, patches, _ = hidden_states.size()
|
||||
|
||||
qkv_states, _ = self.qkv_proj(hidden_states)
|
||||
@@ -1119,7 +1119,7 @@ class PixtralHFTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PixtralVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
*,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
@@ -1155,9 +1155,9 @@ class PixtralHFTransformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PixtralVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
*,
|
||||
num_hidden_layers_override: Optional[int] = None,
|
||||
num_hidden_layers_override: int | None = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -1202,10 +1202,10 @@ class PixtralHFVisionModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PixtralVisionConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
*,
|
||||
num_hidden_layers_override: Optional[int] = None,
|
||||
require_post_norm: Optional[bool] = None,
|
||||
num_hidden_layers_override: int | None = None,
|
||||
require_post_norm: bool | None = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -1247,8 +1247,8 @@ class PixtralHFVisionModel(nn.Module):
|
||||
self,
|
||||
pixel_values: list[torch.Tensor],
|
||||
*,
|
||||
select_layers: Optional[list[int]] = None,
|
||||
feature_select_strategy: Optional[VisionFeatureSelectStrategy] = None,
|
||||
select_layers: list[int] | None = None,
|
||||
feature_select_strategy: VisionFeatureSelectStrategy | None = None,
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
"""
|
||||
Args:
|
||||
|
||||
Reference in New Issue
Block a user