Update Optional[x] to x | None and Union[x, y] to x | y (#26633)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -6,7 +6,7 @@ from collections.abc import Iterable, Mapping, Sequence
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property, partial
|
||||
from itertools import islice
|
||||
from typing import Annotated, Optional, Union
|
||||
from typing import Annotated
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -104,18 +104,18 @@ class MolmoImageInputs(TensorSchema):
|
||||
"""
|
||||
|
||||
images: Annotated[
|
||||
Union[torch.Tensor, list[torch.Tensor]],
|
||||
torch.Tensor | list[torch.Tensor],
|
||||
TensorShape("bn", "nc", "np", "pd", dynamic_dims={"nc"}),
|
||||
]
|
||||
# Number of crops may vary per batch and image, so pass it as a list.
|
||||
|
||||
image_masks: Annotated[
|
||||
Optional[Union[torch.Tensor, list[torch.Tensor]]],
|
||||
torch.Tensor | list[torch.Tensor] | None,
|
||||
TensorShape("bn", "nc", "np", dynamic_dims={"nc"}),
|
||||
]
|
||||
|
||||
image_input_idx: Annotated[
|
||||
Union[torch.Tensor, list[torch.Tensor]],
|
||||
torch.Tensor | list[torch.Tensor],
|
||||
TensorShape("bn", "nc", "tp", dynamic_dims={"nc"}),
|
||||
]
|
||||
# An index tensor that maps image features to their corresponding patch tokens.
|
||||
@@ -151,7 +151,7 @@ class ViTMLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: VisionBackboneConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.w1 = ColumnParallelLinear(
|
||||
@@ -185,7 +185,7 @@ class MultiHeadDotProductAttention(nn.Module):
|
||||
config: VisionBackboneConfig,
|
||||
use_bias: bool = True,
|
||||
nlayers: int = 1,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -238,7 +238,7 @@ class MultiHeadDotProductAttention(nn.Module):
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, inputs_q: torch.Tensor, inputs_kv: Optional[torch.Tensor] = None
|
||||
self, inputs_q: torch.Tensor, inputs_kv: torch.Tensor | None = None
|
||||
) -> torch.Tensor:
|
||||
if inputs_kv is not None:
|
||||
inputs_k = inputs_kv
|
||||
@@ -263,7 +263,7 @@ class ResidualAttentionBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: VisionBackboneConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.attention = MultiHeadDotProductAttention(config, quant_config=quant_config)
|
||||
@@ -289,7 +289,7 @@ class BlockCollection(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: VisionBackboneConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.resblocks = nn.ModuleList(
|
||||
@@ -317,7 +317,7 @@ class VisionTransformer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: VisionBackboneConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
scale = config.image_emb_dim**-0.5
|
||||
@@ -367,7 +367,7 @@ class VisionTransformer(nn.Module):
|
||||
return x
|
||||
|
||||
def forward(
|
||||
self, x: torch.Tensor, patch_num: Optional[int] = None
|
||||
self, x: torch.Tensor, patch_num: int | None = None
|
||||
) -> list[torch.Tensor]:
|
||||
"""
|
||||
: param x: (batch_size, num_patch, n_pixels)
|
||||
@@ -396,8 +396,8 @@ class MolmoAttention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
cache_config: CacheConfig | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -432,9 +432,9 @@ class MolmoAttention(nn.Module):
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
self.tp_rank: Optional[int] = None
|
||||
self.k_norm: Optional[nn.Module] = None
|
||||
self.q_norm: Optional[nn.Module] = None
|
||||
self.tp_rank: int | None = None
|
||||
self.k_norm: nn.Module | None = None
|
||||
self.q_norm: nn.Module | None = None
|
||||
if config.attention_layer_norm:
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.k_norm = RMSNorm(
|
||||
@@ -503,8 +503,8 @@ class LanguageModelMLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
input_dim: Optional[int] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
input_dim: int | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@@ -542,8 +542,8 @@ class ImageProjectorMLP(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
input_dim: Optional[int] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
input_dim: int | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
@@ -580,8 +580,8 @@ class MolmoDecoderLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
cache_config: CacheConfig | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -604,8 +604,8 @@ class MolmoDecoderLayer(nn.Module):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, torch.Tensor]]]:
|
||||
residual: torch.Tensor | None,
|
||||
) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]:
|
||||
# Self Attention
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
@@ -627,8 +627,8 @@ class MolmoDecoderNormAfterLayer(MolmoDecoderLayer):
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor, torch.Tensor]]]:
|
||||
residual: torch.Tensor | None,
|
||||
) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]:
|
||||
# Self Attention
|
||||
residual = hidden_states
|
||||
hidden_states = self.self_attn(
|
||||
@@ -654,7 +654,7 @@ class MolmoVisionBackbone(nn.Module, SupportsQuant):
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
vision_config: VisionBackboneConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.vit_layers = VIT_LAYERS
|
||||
@@ -849,8 +849,8 @@ class MolmoModel(nn.Module, SupportsQuant):
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
intermediate_tensors: IntermediateTensors | None = None,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
if get_pp_group().is_first_rank:
|
||||
if inputs_embeds is not None:
|
||||
@@ -1064,7 +1064,7 @@ class MolmoProcessorWrapper:
|
||||
return image_token_length_h
|
||||
|
||||
@property
|
||||
def message_format(self) -> Optional[str]:
|
||||
def message_format(self) -> str | None:
|
||||
return "role"
|
||||
|
||||
@property
|
||||
@@ -1145,9 +1145,9 @@ class MolmoProcessorWrapper:
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
text: Optional[Union[TextInput, list[TextInput]]] = None,
|
||||
images: Optional[Union[ImageInput, list[ImageInput]]] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
text: TextInput | list[TextInput] | None = None,
|
||||
images: ImageInput | list[ImageInput] | None = None,
|
||||
return_tensors: str | TensorType | None = None,
|
||||
**kwargs,
|
||||
) -> BatchFeature:
|
||||
outputs = self.processor.process( # type: ignore
|
||||
@@ -1189,7 +1189,7 @@ class MolmoProcessingInfo(BaseProcessingInfo):
|
||||
processor = self.ctx.get_hf_processor(**kwargs)
|
||||
return MolmoProcessorWrapper(processor)
|
||||
|
||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||
def get_supported_mm_limits(self) -> Mapping[str, int | None]:
|
||||
return {"image": None}
|
||||
|
||||
def get_num_image_tokens(
|
||||
@@ -1197,7 +1197,7 @@ class MolmoProcessingInfo(BaseProcessingInfo):
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
processor: Optional[MolmoProcessorWrapper],
|
||||
processor: MolmoProcessorWrapper | None,
|
||||
) -> int:
|
||||
if processor is None:
|
||||
processor = self.get_hf_processor()
|
||||
@@ -1250,7 +1250,7 @@ class MolmoDummyInputsBuilder(BaseDummyInputsBuilder[MolmoProcessingInfo]):
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
|
||||
mm_options: Mapping[str, BaseDummyOptions] | None = None,
|
||||
) -> MultiModalDataDict:
|
||||
target_width, target_height = self.info.get_image_size_with_most_features()
|
||||
num_images = mm_counts.get("image", 0)
|
||||
@@ -1398,7 +1398,7 @@ class MolmoForCausalLM(
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> str | None:
|
||||
if modality.startswith("image"):
|
||||
return None
|
||||
|
||||
@@ -1442,7 +1442,7 @@ class MolmoForCausalLM(
|
||||
def _parse_and_validate_image_input(
|
||||
self,
|
||||
**kwargs: object,
|
||||
) -> Optional[MolmoImageInputs]:
|
||||
) -> MolmoImageInputs | None:
|
||||
images = kwargs.pop("images", None)
|
||||
image_masks = kwargs.pop("image_masks", None)
|
||||
image_input_idx = kwargs.pop("image_input_idx", None)
|
||||
@@ -1522,8 +1522,8 @@ class MolmoForCausalLM(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
positions: torch.LongTensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
intermediate_tensors: IntermediateTensors | None = None,
|
||||
inputs_embeds: torch.Tensor | None = None,
|
||||
**kwargs: object,
|
||||
) -> torch.Tensor:
|
||||
if intermediate_tensors is not None:
|
||||
|
||||
Reference in New Issue
Block a user