Update Optional[x] to x | None and Union[x, y] to x | y (#26633)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-12 17:51:31 +01:00
committed by GitHub
parent 9bb38130cb
commit 8fcaaf6a16
944 changed files with 9490 additions and 10121 deletions

View File

@@ -6,7 +6,7 @@ from collections.abc import Iterable, Mapping, Sequence
from dataclasses import dataclass
from functools import cached_property, partial
from itertools import islice
from typing import Annotated, Optional, Union
from typing import Annotated
import numpy as np
import torch
@@ -104,18 +104,18 @@ class MolmoImageInputs(TensorSchema):
"""
images: Annotated[
Union[torch.Tensor, list[torch.Tensor]],
torch.Tensor | list[torch.Tensor],
TensorShape("bn", "nc", "np", "pd", dynamic_dims={"nc"}),
]
# Number of crops may vary per batch and image, so pass it as a list.
image_masks: Annotated[
Optional[Union[torch.Tensor, list[torch.Tensor]]],
torch.Tensor | list[torch.Tensor] | None,
TensorShape("bn", "nc", "np", dynamic_dims={"nc"}),
]
image_input_idx: Annotated[
Union[torch.Tensor, list[torch.Tensor]],
torch.Tensor | list[torch.Tensor],
TensorShape("bn", "nc", "tp", dynamic_dims={"nc"}),
]
# An index tensor that maps image features to their corresponding patch tokens.
@@ -151,7 +151,7 @@ class ViTMLP(nn.Module):
def __init__(
self,
config: VisionBackboneConfig,
quant_config: Optional[QuantizationConfig] = None,
quant_config: QuantizationConfig | None = None,
):
super().__init__()
self.w1 = ColumnParallelLinear(
@@ -185,7 +185,7 @@ class MultiHeadDotProductAttention(nn.Module):
config: VisionBackboneConfig,
use_bias: bool = True,
nlayers: int = 1,
quant_config: Optional[QuantizationConfig] = None,
quant_config: QuantizationConfig | None = None,
):
super().__init__()
@@ -238,7 +238,7 @@ class MultiHeadDotProductAttention(nn.Module):
)
def forward(
self, inputs_q: torch.Tensor, inputs_kv: Optional[torch.Tensor] = None
self, inputs_q: torch.Tensor, inputs_kv: torch.Tensor | None = None
) -> torch.Tensor:
if inputs_kv is not None:
inputs_k = inputs_kv
@@ -263,7 +263,7 @@ class ResidualAttentionBlock(nn.Module):
def __init__(
self,
config: VisionBackboneConfig,
quant_config: Optional[QuantizationConfig] = None,
quant_config: QuantizationConfig | None = None,
):
super().__init__()
self.attention = MultiHeadDotProductAttention(config, quant_config=quant_config)
@@ -289,7 +289,7 @@ class BlockCollection(nn.Module):
def __init__(
self,
config: VisionBackboneConfig,
quant_config: Optional[QuantizationConfig] = None,
quant_config: QuantizationConfig | None = None,
):
super().__init__()
self.resblocks = nn.ModuleList(
@@ -317,7 +317,7 @@ class VisionTransformer(nn.Module):
def __init__(
self,
config: VisionBackboneConfig,
quant_config: Optional[QuantizationConfig] = None,
quant_config: QuantizationConfig | None = None,
):
super().__init__()
scale = config.image_emb_dim**-0.5
@@ -367,7 +367,7 @@ class VisionTransformer(nn.Module):
return x
def forward(
self, x: torch.Tensor, patch_num: Optional[int] = None
self, x: torch.Tensor, patch_num: int | None = None
) -> list[torch.Tensor]:
"""
:param x: (batch_size, num_patch, n_pixels)
@@ -396,8 +396,8 @@ class MolmoAttention(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
@@ -432,9 +432,9 @@ class MolmoAttention(nn.Module):
quant_config=quant_config,
)
self.tp_rank: Optional[int] = None
self.k_norm: Optional[nn.Module] = None
self.q_norm: Optional[nn.Module] = None
self.tp_rank: int | None = None
self.k_norm: nn.Module | None = None
self.q_norm: nn.Module | None = None
if config.attention_layer_norm:
self.tp_rank = get_tensor_model_parallel_rank()
self.k_norm = RMSNorm(
@@ -503,8 +503,8 @@ class LanguageModelMLP(nn.Module):
def __init__(
self,
config: PretrainedConfig,
input_dim: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
input_dim: int | None = None,
quant_config: QuantizationConfig | None = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@@ -542,8 +542,8 @@ class ImageProjectorMLP(nn.Module):
def __init__(
self,
config: PretrainedConfig,
input_dim: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
input_dim: int | None = None,
quant_config: QuantizationConfig | None = None,
) -> None:
super().__init__()
self.hidden_size = config.hidden_size
@@ -580,8 +580,8 @@ class MolmoDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
cache_config: CacheConfig | None = None,
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> None:
super().__init__()
@@ -604,8 +604,8 @@ class MolmoDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, torch.Tensor]]]:
residual: torch.Tensor | None,
) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]:
# Self Attention
if residual is None:
residual = hidden_states
@@ -627,8 +627,8 @@ class MolmoDecoderNormAfterLayer(MolmoDecoderLayer):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
residual: Optional[torch.Tensor],
) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, torch.Tensor]]]:
residual: torch.Tensor | None,
) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]:
# Self Attention
residual = hidden_states
hidden_states = self.self_attn(
@@ -654,7 +654,7 @@ class MolmoVisionBackbone(nn.Module, SupportsQuant):
self,
config: PretrainedConfig,
vision_config: VisionBackboneConfig,
quant_config: Optional[QuantizationConfig] = None,
quant_config: QuantizationConfig | None = None,
) -> None:
super().__init__()
self.vit_layers = VIT_LAYERS
@@ -849,8 +849,8 @@ class MolmoModel(nn.Module, SupportsQuant):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
@@ -1064,7 +1064,7 @@ class MolmoProcessorWrapper:
return image_token_length_h
@property
def message_format(self) -> Optional[str]:
def message_format(self) -> str | None:
return "role"
@property
@@ -1145,9 +1145,9 @@ class MolmoProcessorWrapper:
def __call__(
self,
text: Optional[Union[TextInput, list[TextInput]]] = None,
images: Optional[Union[ImageInput, list[ImageInput]]] = None,
return_tensors: Optional[Union[str, TensorType]] = None,
text: TextInput | list[TextInput] | None = None,
images: ImageInput | list[ImageInput] | None = None,
return_tensors: str | TensorType | None = None,
**kwargs,
) -> BatchFeature:
outputs = self.processor.process( # type: ignore
@@ -1189,7 +1189,7 @@ class MolmoProcessingInfo(BaseProcessingInfo):
processor = self.ctx.get_hf_processor(**kwargs)
return MolmoProcessorWrapper(processor)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
return {"image": None}
def get_num_image_tokens(
@@ -1197,7 +1197,7 @@ class MolmoProcessingInfo(BaseProcessingInfo):
*,
image_width: int,
image_height: int,
processor: Optional[MolmoProcessorWrapper],
processor: MolmoProcessorWrapper | None,
) -> int:
if processor is None:
processor = self.get_hf_processor()
@@ -1250,7 +1250,7 @@ class MolmoDummyInputsBuilder(BaseDummyInputsBuilder[MolmoProcessingInfo]):
self,
seq_len: int,
mm_counts: Mapping[str, int],
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
mm_options: Mapping[str, BaseDummyOptions] | None = None,
) -> MultiModalDataDict:
target_width, target_height = self.info.get_image_size_with_most_features()
num_images = mm_counts.get("image", 0)
@@ -1398,7 +1398,7 @@ class MolmoForCausalLM(
}
@classmethod
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
if modality.startswith("image"):
return None
@@ -1442,7 +1442,7 @@ class MolmoForCausalLM(
def _parse_and_validate_image_input(
self,
**kwargs: object,
) -> Optional[MolmoImageInputs]:
) -> MolmoImageInputs | None:
images = kwargs.pop("images", None)
image_masks = kwargs.pop("image_masks", None)
image_input_idx = kwargs.pop("image_input_idx", None)
@@ -1522,8 +1522,8 @@ class MolmoForCausalLM(
self,
input_ids: torch.LongTensor,
positions: torch.LongTensor,
intermediate_tensors: Optional[IntermediateTensors] = None,
inputs_embeds: Optional[torch.Tensor] = None,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs: object,
) -> torch.Tensor:
if intermediate_tensors is not None: