Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-10-12 17:51:31 +01:00
Committed by: GitHub
Parent: 9bb38130cb
Commit: 8fcaaf6a16
944 changed files with 9490 additions and 10121 deletions
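
For context, the pattern applied throughout the diff is the PEP 604 / PEP 585 style available on Python 3.10+: Optional[X] becomes X | None, Union[X, Y] becomes X | Y, typing.Callable moves to collections.abc.Callable, and module-level aliases gain an explicit TypeAlias annotation. A minimal before/after sketch with illustrative names (not taken from this file):

# Before (typing-module spellings):
#     from typing import Callable, Optional, Union
#     def resolve(limit: Optional[int] = None,
#                 key: Union[str, int] = 0,
#                 fn: Optional[Callable[[int], int]] = None) -> Optional[int]: ...

# After (PEP 604 unions, Python 3.10+):
from collections.abc import Callable


def resolve(
    limit: int | None = None,
    key: str | int = 0,
    fn: Callable[[int], int] | None = None,
) -> int | None:
    # Apply fn when provided; a None limit means "no cap".
    value = key if isinstance(key, int) else len(key)
    if fn is not None:
        value = fn(value)
    return value if limit is None else min(value, limit)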


@@ -27,9 +27,9 @@
 """Inference-only GLM-4V model compatible with HuggingFace weights."""
 import math
-from collections.abc import Iterable, Mapping, Sequence
+from collections.abc import Callable, Iterable, Mapping, Sequence
 from functools import partial
-from typing import Annotated, Any, Callable, Literal, Optional, Union
+from typing import Annotated, Any, Literal, TypeAlias
 import numpy as np
 import torch
@@ -140,7 +140,7 @@ class Glm4vImageEmbeddingInputs(TensorSchema):
     image_grid_thw: Annotated[torch.Tensor, TensorShape("n", 3)]
-Glm4vImageInputs = Union[Glm4vImagePixelInputs, Glm4vImageEmbeddingInputs]
+Glm4vImageInputs: TypeAlias = Glm4vImagePixelInputs | Glm4vImageEmbeddingInputs
 class Glm4vVideoPixelInputs(TensorSchema):
@@ -176,7 +176,7 @@ class Glm4vVideoEmbeddingInputs(TensorSchema):
     video_grid_thw: Annotated[torch.Tensor, TensorShape("f", 3)]
-Glm4vVideoInputs = Union[Glm4vVideoPixelInputs, Glm4vVideoEmbeddingInputs]
+Glm4vVideoInputs: TypeAlias = Glm4vVideoPixelInputs | Glm4vVideoEmbeddingInputs
 # ==== Vision Encoder ==== #
@@ -187,7 +187,7 @@ class Glm4vVisionMLP(nn.Module):
         in_features: int,
         hidden_features: int,
         bias: bool = False,
-        quant_config: Optional[QuantizationConfig] = None,
+        quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
     ):
@@ -244,7 +244,7 @@ class Glm4vVisionAttention(nn.Module):
         embed_dim: int,
         num_heads: int,
         projection_size: int,
-        quant_config: Optional[QuantizationConfig] = None,
+        quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
     ) -> None:
@@ -334,8 +334,8 @@ class Glm4vVisionAttention(nn.Module):
         x: torch.Tensor,
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: torch.Tensor,
-        max_seqlen: Optional[int] = None,  # Only used for Flash Attention
-        seqlens: Optional[list[int]] = None,  # Only used for xFormers
+        max_seqlen: int | None = None,  # Only used for Flash Attention
+        seqlens: list[int] | None = None,  # Only used for xFormers
     ) -> torch.Tensor:
         # [s, b, c] --> [s, b, head * 3 * head_dim]
         x, _ = self.qkv(x)
@@ -413,8 +413,8 @@ class Glm4vVisionBlock(nn.Module):
         dim: int,
         num_heads: int,
         mlp_hidden_dim: int,
-        norm_layer: Optional[Callable[[int], nn.Module]] = None,
-        quant_config: Optional[QuantizationConfig] = None,
+        norm_layer: Callable[[int], nn.Module] | None = None,
+        quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
     ) -> None:
@@ -445,8 +445,8 @@ class Glm4vVisionBlock(nn.Module):
         x: torch.Tensor,
         cu_seqlens: torch.Tensor,
         rotary_pos_emb: torch.Tensor,
-        max_seqlen: Optional[int] = None,  # Only used for Flash Attention
-        seqlens: Optional[list[int]] = None,  # Only used for xFormers
+        max_seqlen: int | None = None,  # Only used for Flash Attention
+        seqlens: list[int] | None = None,  # Only used for xFormers
     ) -> torch.Tensor:
         x_attn = self.attn(
             self.norm1(x),
@@ -495,7 +495,7 @@ class Glm4vPatchMerger(nn.Module):
         self,
         d_model: int,
         context_dim: int,
-        quant_config: Optional[QuantizationConfig] = None,
+        quant_config: QuantizationConfig | None = None,
         bias: bool = False,
         prefix: str = "",
         use_data_parallel: bool = False,
@@ -693,7 +693,7 @@ class Glm4vVisionTransformer(nn.Module):
         self,
         vision_config: Glm4vVisionConfig,
         norm_eps: float = 1e-6,
-        quant_config: Optional[QuantizationConfig] = None,
+        quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
     ) -> None:
@@ -809,7 +809,7 @@ class Glm4vVisionTransformer(nn.Module):
     def compute_attn_mask_seqlen(
         self,
         cu_seqlens: torch.Tensor,
-    ) -> tuple[Optional[int], Optional[list[int]]]:
+    ) -> tuple[int | None, list[int] | None]:
         max_seqlen, seqlens = None, None
         seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
         if (
@@ -904,7 +904,7 @@ class Glm4vProcessingInfo(BaseProcessingInfo):
     def get_tokenizer(self):
         return self.ctx.tokenizer
-    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
         return {"image": None, "video": 1}
     def get_image_processor(self, **kwargs: object) -> Glm4vImageProcessor:
@@ -1141,7 +1141,7 @@ class Glm4vDummyInputsBuilder(BaseDummyInputsBuilder[Glm4vProcessingInfo]):
         self,
         seq_len: int,
         mm_counts: Mapping[str, int],
-        mm_options: Optional[Mapping[str, BaseDummyOptions]] = None,
+        mm_options: Mapping[str, BaseDummyOptions] | None = None,
     ) -> MultiModalDataDict:
         num_images = mm_counts.get("image", 0)
         num_videos = mm_counts.get("video", 0)
@@ -1177,7 +1177,7 @@ class Glm4vDummyInputsBuilder(BaseDummyInputsBuilder[Glm4vProcessingInfo]):
         height: int,
         num_frames: int,
         num_videos: int,
-        overrides: Optional[VideoDummyOptions] = None,
+        overrides: VideoDummyOptions | None = None,
     ) -> list[VideoItem]:
         if overrides:
             if overrides.num_frames:
@@ -1419,7 +1419,7 @@ class Glm4vForConditionalGeneration(
     supports_encoder_tp_data = True
     @classmethod
-    def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
+    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
         if modality.startswith("image"):
             return "<|begin_of_image|><|image|><|end_of_image|>"
         if modality.startswith("video"):
@@ -1465,7 +1465,7 @@ class Glm4vForConditionalGeneration(
     def _parse_and_validate_image_input(
         self, **kwargs: object
-    ) -> Optional[Glm4vImageInputs]:
+    ) -> Glm4vImageInputs | None:
         pixel_values = kwargs.pop("pixel_values", None)
         image_embeds = kwargs.pop("image_embeds", None)
         image_grid_thw = kwargs.pop("image_grid_thw", None)
@@ -1489,7 +1489,7 @@ class Glm4vForConditionalGeneration(
     def _parse_and_validate_video_input(
         self, **kwargs: object
-    ) -> Optional[Glm4vVideoInputs]:
+    ) -> Glm4vVideoInputs | None:
         pixel_values_videos = kwargs.pop("pixel_values_videos", None)
         video_embeds = kwargs.pop("video_embeds", None)
         video_grid_thw = kwargs.pop("video_grid_thw", None)
@@ -1594,7 +1594,7 @@ class Glm4vForConditionalGeneration(
     def get_multimodal_embeddings(
         self, **kwargs: object
-    ) -> Optional[MultiModalEmbeddings]:
+    ) -> MultiModalEmbeddings | None:
         mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
         if not mm_input_by_modality:
             return None
@@ -1619,10 +1619,10 @@ class Glm4vForConditionalGeneration(
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
+        intermediate_tensors: IntermediateTensors | None = None,
+        inputs_embeds: torch.Tensor | None = None,
         **kwargs: object,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
+    ) -> torch.Tensor | IntermediateTensors:
         """Run forward pass for GLM-4V.
         Args:
@@ -1652,7 +1652,7 @@ class Glm4vForConditionalGeneration(
     def compute_logits(
         self,
         hidden_states: torch.Tensor,
-    ) -> Optional[torch.Tensor]:
+    ) -> torch.Tensor | None:
         return self.language_model.compute_logits(hidden_states)
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
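
Rewrites of this shape are usually applied mechanically; pyupgrade's --py310-plus mode and Ruff's UP007 rule perform the same Optional/Union conversion, so a 944-file change like this one was presumably tool-assisted rather than hand-edited (an assumption; the commit message does not say which tool, if any, was used).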