Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-12 17:51:31 +01:00
committed by GitHub
parent 9bb38130cb
commit 8fcaaf6a16
944 changed files with 9490 additions and 10121 deletions

View File

@@ -45,7 +45,6 @@
from collections.abc import Sequence
from copy import deepcopy
from functools import cached_property
from typing import Optional, Union
import torch
import torch.nn as nn
@@ -68,8 +67,8 @@ def multihead_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
q_cu_seqlens: Optional[torch.Tensor] = None,
k_cu_seqlens: Optional[torch.Tensor] = None,
q_cu_seqlens: torch.Tensor | None = None,
k_cu_seqlens: torch.Tensor | None = None,
) -> torch.Tensor:
"""Multi-head attention using flash attention 2.
@@ -121,8 +120,8 @@ def sdpa_attention(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
q_cu_seqlens: Optional[torch.Tensor] = None,
k_cu_seqlens: Optional[torch.Tensor] = None,
q_cu_seqlens: torch.Tensor | None = None,
k_cu_seqlens: torch.Tensor | None = None,
) -> torch.Tensor:
"""SDPA attention.
@@ -230,7 +229,7 @@ class MoonVisionPatchEmbed(nn.Module):
self,
out_dim: int,
in_dim: int = 3,
patch_size: Union[int, tuple[int, int]] = (14, 14),
patch_size: int | tuple[int, int] = (14, 14),
pos_emb_height: int = 14,
pos_emb_width: int = 14,
):
@@ -460,7 +459,7 @@ class MoonVitEncoderLayer(nn.Module):
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
rope_freqs_cis: Optional[torch.Tensor] = None,
rope_freqs_cis: torch.Tensor | None = None,
):
"""
Args:
@@ -491,7 +490,7 @@ class MoonVitEncoderLayer(nn.Module):
self,
hidden_states: torch.Tensor,
cu_seqlens: torch.Tensor,
rope_freqs_cis: Union[torch.Tensor, None] = None,
rope_freqs_cis: torch.Tensor | None = None,
) -> torch.Tensor:
"""
Args: