Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -45,7 +45,6 @@
|
||||
from collections.abc import Sequence
|
||||
from copy import deepcopy
|
||||
from functools import cached_property
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -68,8 +67,8 @@ def multihead_attention(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
q_cu_seqlens: Optional[torch.Tensor] = None,
|
||||
k_cu_seqlens: Optional[torch.Tensor] = None,
|
||||
q_cu_seqlens: torch.Tensor | None = None,
|
||||
k_cu_seqlens: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Multi-head attention using flash attention 2.
|
||||
|
||||
@@ -121,8 +120,8 @@ def sdpa_attention(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
q_cu_seqlens: Optional[torch.Tensor] = None,
|
||||
k_cu_seqlens: Optional[torch.Tensor] = None,
|
||||
q_cu_seqlens: torch.Tensor | None = None,
|
||||
k_cu_seqlens: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""SDPA attention.
|
||||
|
||||
@@ -230,7 +229,7 @@ class MoonVisionPatchEmbed(nn.Module):
|
||||
self,
|
||||
out_dim: int,
|
||||
in_dim: int = 3,
|
||||
patch_size: Union[int, tuple[int, int]] = (14, 14),
|
||||
patch_size: int | tuple[int, int] = (14, 14),
|
||||
pos_emb_height: int = 14,
|
||||
pos_emb_width: int = 14,
|
||||
):
|
||||
@@ -460,7 +459,7 @@ class MoonVitEncoderLayer(nn.Module):
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
rope_freqs_cis: Optional[torch.Tensor] = None,
|
||||
rope_freqs_cis: torch.Tensor | None = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@@ -491,7 +490,7 @@ class MoonVitEncoderLayer(nn.Module):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
rope_freqs_cis: Union[torch.Tensor, None] = None,
|
||||
rope_freqs_cis: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
|
||||
Reference in New Issue
Block a user