Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Harry Mellor
Date: 2025-10-12 17:51:31 +01:00 (committed by GitHub)
parent 9bb38130cb
commit 8fcaaf6a16
944 changed files with 9490 additions and 10121 deletions
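
Below is a hedged sketch of the pattern this commit applies across the tree: `typing.Optional`/`typing.Union` annotations rewritten to PEP 604 union syntax. The function `scale_fwd` is illustrative only, not taken from the diff.

# Before: the annotation needs an import from typing.
import torch
from typing import Optional

def scale_fwd_old(x: torch.Tensor, scale: Optional[float] = None) -> torch.Tensor:
    return x if scale is None else x * scale

# After: builtin `|` unions (PEP 604); the typing import can be dropped.
# Note: `X | None` in runtime-evaluated annotations requires Python 3.10+,
# unless `from __future__ import annotations` defers evaluation.
def scale_fwd(x: torch.Tensor, scale: float | None = None) -> torch.Tensor:
    return x if scale is None else x * scale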

View File

@@ -8,7 +8,6 @@
 # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
 # ruff: noqa: E501
 import warnings
-from typing import Optional
 import torch
 from einops import rearrange
@@ -32,7 +31,7 @@ def chunk_gated_delta_rule_fwd(
     scale: float,
     initial_state: torch.Tensor,
     output_final_state: bool,
-    cu_seqlens: Optional[torch.LongTensor] = None,
+    cu_seqlens: torch.LongTensor | None = None,
 ):
     g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens)
     # obtain WY representation. u is actually the new v.
@@ -86,7 +85,7 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
         scale: float,
         initial_state: torch.Tensor,
         output_final_state: bool,
-        cu_seqlens: Optional[torch.LongTensor] = None,
+        cu_seqlens: torch.LongTensor | None = None,
         use_qk_l2norm_in_kernel: bool = False,
     ):
         if use_qk_l2norm_in_kernel:
@@ -119,7 +118,7 @@ def chunk_gated_delta_rule(
     scale: float = None,
     initial_state: torch.Tensor = None,
     output_final_state: bool = False,
-    cu_seqlens: Optional[torch.LongTensor] = None,
+    cu_seqlens: torch.LongTensor | None = None,
     head_first: bool = False,
     use_qk_l2norm_in_kernel: bool = False,
 ):

View File

@@ -7,7 +7,6 @@
 # the following copyright notice:
 # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
 # ruff: noqa: E501
-from typing import Optional
 import torch
@@ -257,12 +256,12 @@ def chunk_gated_delta_rule_fwd_h(
     k: torch.Tensor,
     w: torch.Tensor,
     u: torch.Tensor,
-    g: Optional[torch.Tensor] = None,
-    initial_state: Optional[torch.Tensor] = None,
+    g: torch.Tensor | None = None,
+    initial_state: torch.Tensor | None = None,
     output_final_state: bool = False,
     chunk_size: int = 64,  # SY: remove this argument and force chunk size 64?
     save_new_value: bool = True,
-    cu_seqlens: Optional[torch.LongTensor] = None,
+    cu_seqlens: torch.LongTensor | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     B, T, Hg, K, V = *k.shape, u.shape[-1]
     H = u.shape[-2]

View File

@@ -9,7 +9,6 @@
 # ruff: noqa: E501
-from typing import Optional
 import torch
@@ -144,9 +143,9 @@ def chunk_fwd_o(
     k: torch.Tensor,
     v: torch.Tensor,
     h: torch.Tensor,
-    g: Optional[torch.Tensor] = None,  # cumsum of log decay
-    scale: Optional[float] = None,
-    cu_seqlens: Optional[torch.LongTensor] = None,
+    g: torch.Tensor | None = None,  # cumsum of log decay
+    scale: float | None = None,
+    cu_seqlens: torch.LongTensor | None = None,
     chunk_size: int = 64,
 ) -> torch.Tensor:
     B, T, Hg, K, V = *q.shape, v.shape[-1]

View File

@@ -7,7 +7,6 @@
 # the following copyright notice:
 # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
 # ruff: noqa: E501
-from typing import Optional
 import torch
@@ -104,8 +103,8 @@ def chunk_scaled_dot_kkt_fwd_kernel(
 def chunk_scaled_dot_kkt_fwd(
     k: torch.Tensor,
     beta: torch.Tensor,
-    g_cumsum: Optional[torch.Tensor] = None,
-    cu_seqlens: Optional[torch.LongTensor] = None,
+    g_cumsum: torch.Tensor | None = None,
+    cu_seqlens: torch.LongTensor | None = None,
     chunk_size: int = 64,
     output_dtype: torch.dtype = torch.float32,
 ) -> torch.Tensor:

View File

@@ -8,7 +8,6 @@
 # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
 # ruff: noqa: E501
 import warnings
-from typing import Optional
 import torch
@@ -163,9 +162,9 @@ def chunk_local_cumsum_scalar(
     g: torch.Tensor,
     chunk_size: int,
     reverse: bool = False,
-    cu_seqlens: Optional[torch.Tensor] = None,
+    cu_seqlens: torch.Tensor | None = None,
     head_first: bool = False,
-    output_dtype: Optional[torch.dtype] = torch.float,
+    output_dtype: torch.dtype | None = torch.float,
 ) -> torch.Tensor:
     if head_first:
         B, H, T = g.shape
@@ -200,9 +199,9 @@ def chunk_local_cumsum_vector(
     g: torch.Tensor,
     chunk_size: int,
     reverse: bool = False,
-    cu_seqlens: Optional[torch.Tensor] = None,
+    cu_seqlens: torch.Tensor | None = None,
     head_first: bool = False,
-    output_dtype: Optional[torch.dtype] = torch.float,
+    output_dtype: torch.dtype | None = torch.float,
 ) -> torch.Tensor:
     if head_first:
         B, H, T, S = g.shape
@@ -248,9 +247,9 @@ def chunk_local_cumsum(
     g: torch.Tensor,
     chunk_size: int,
     reverse: bool = False,
-    cu_seqlens: Optional[torch.Tensor] = None,
+    cu_seqlens: torch.Tensor | None = None,
     head_first: bool = False,
-    output_dtype: Optional[torch.dtype] = torch.float,
+    output_dtype: torch.dtype | None = torch.float,
     **kwargs,
 ) -> torch.Tensor:
     if not head_first and g.shape[1] < g.shape[2]:

View File

@@ -7,7 +7,6 @@
 # the following copyright notice:
 # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
 # ruff: noqa: E501
-from typing import Optional
 import torch
@@ -169,9 +168,9 @@ def fused_recurrent_gated_delta_rule_fwd(
     scale: float,
     initial_state: torch.Tensor,
     inplace_final_state: bool = True,
-    cu_seqlens: Optional[torch.LongTensor] = None,
-    ssm_state_indices: Optional[torch.Tensor] = None,
-    num_accepted_tokens: Optional[torch.Tensor] = None,
+    cu_seqlens: torch.LongTensor | None = None,
+    ssm_state_indices: torch.Tensor | None = None,
+    num_accepted_tokens: torch.Tensor | None = None,
     use_qk_l2norm_in_kernel: bool = False,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     B, T, H, K, V = *k.shape, v.shape[-1]
@@ -248,9 +247,9 @@ class FusedRecurrentFunction(torch.autograd.Function):
         scale: float,
         initial_state: torch.Tensor,
         inplace_final_state: bool = True,
-        cu_seqlens: Optional[torch.LongTensor] = None,
-        ssm_state_indices: Optional[torch.Tensor] = None,
-        num_accepted_tokens: Optional[torch.Tensor] = None,
+        cu_seqlens: torch.LongTensor | None = None,
+        ssm_state_indices: torch.Tensor | None = None,
+        num_accepted_tokens: torch.Tensor | None = None,
         use_qk_l2norm_in_kernel: bool = False,
     ):
         o, final_state = fused_recurrent_gated_delta_rule_fwd(
@@ -280,9 +279,9 @@ def fused_recurrent_gated_delta_rule(
     scale: float = None,
     initial_state: torch.Tensor = None,
     inplace_final_state: bool = True,
-    cu_seqlens: Optional[torch.LongTensor] = None,
-    ssm_state_indices: Optional[torch.Tensor] = None,
-    num_accepted_tokens: Optional[torch.Tensor] = None,
+    cu_seqlens: torch.LongTensor | None = None,
+    ssm_state_indices: torch.Tensor | None = None,
+    num_accepted_tokens: torch.Tensor | None = None,
     use_qk_l2norm_in_kernel: bool = False,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     r"""

View File

@@ -8,7 +8,6 @@
 # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
 import os
-from typing import Optional
 import torch
@@ -90,7 +89,7 @@ def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr):
 def l2norm_fwd(
-    x: torch.Tensor, eps: float = 1e-6, output_dtype: Optional[torch.dtype] = None
+    x: torch.Tensor, eps: float = 1e-6, output_dtype: torch.dtype | None = None
 ):
     x_shape_og = x.shape
     x = x.view(-1, x.shape[-1])

View File

@@ -14,7 +14,6 @@
 # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
 from functools import lru_cache
-from typing import Optional
 import torch
 import torch.nn as nn
@@ -324,10 +323,10 @@ class LayerNormGated(nn.Module):
         self,
         hidden_size,
         eps: float = 1e-5,
-        group_size: Optional[int] = None,
+        group_size: int | None = None,
         norm_before_gate: bool = True,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
+        device: torch.device | None = None,
+        dtype: torch.dtype | None = None,
     ):
         """If group_size is not None, we do GroupNorm with each group having group_size elements.
         group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
@@ -364,10 +363,10 @@ class RMSNormGated(nn.Module):
         self,
         hidden_size,
         eps: float = 1e-5,
-        group_size: Optional[int] = None,
+        group_size: int | None = None,
         norm_before_gate: bool = False,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
+        device: torch.device | None = None,
+        dtype: torch.dtype | None = None,
     ):
         """If group_size is not None, we do GroupNorm with each group having group_size elements.
         group_size=None is equivalent to group_size=hidden_size (i.e. there's only 1 group).
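
The docstring above describes the group_size semantics. Here is a minimal sketch of that equivalence in plain torch, assuming standard grouped-RMSNorm behavior; it is an illustration, not the module's Triton implementation.

import torch

def grouped_rms_norm(x: torch.Tensor, group_size: int | None, eps: float = 1e-5):
    hidden = x.shape[-1]
    gs = hidden if group_size is None else group_size  # None -> a single group
    xg = x.view(*x.shape[:-1], hidden // gs, gs)
    rms = torch.rsqrt(xg.pow(2).mean(-1, keepdim=True) + eps)
    return (xg * rms).view_as(x)

x = torch.randn(2, 8)
# group_size=None is equivalent to group_size=hidden_size (one group):
assert torch.allclose(grouped_rms_norm(x, None), grouped_rms_norm(x, 8))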

View File

@@ -7,7 +7,6 @@
 # the following copyright notice:
 # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
 # ruff: noqa: E501
-from typing import Optional
 import torch
@@ -407,7 +406,7 @@ def merge_16x16_to_64x64_inverse_kernel(
 @input_guard
 def solve_tril(
     A: torch.Tensor,
-    cu_seqlens: Optional[torch.Tensor] = None,
+    cu_seqlens: torch.Tensor | None = None,
     output_dtype: torch.dtype = torch.float,
 ) -> torch.Tensor:
     """

View File

@@ -11,8 +11,9 @@ import contextlib
 import functools
 import logging
 import os
+from collections.abc import Callable
 from enum import Enum
-from typing import Any, Callable, Literal, Optional
+from typing import Any, Literal
 import torch
@@ -43,7 +44,7 @@ def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]
     A wrapped version of the input function with single-entry caching.
     """
-    cache_entries: tuple[Optional[tuple], Optional[dict], Any] = []
+    cache_entries: tuple[tuple | None, dict | None, Any] = []
     cache_size = 4
     @functools.wraps(fn)
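
The hunk above touches tensor_cache, whose docstring describes caching over a few recent calls. A simplified sketch of that idea follows; it memoizes with plain equality on the arguments, whereas the real wrapper must compare tensor arguments, so treat it as an illustration only.

import functools
from collections.abc import Callable
from typing import Any

def tensor_cache_sketch(fn: Callable[..., Any], cache_size: int = 4) -> Callable[..., Any]:
    cache_entries: list[tuple[tuple, dict, Any]] = []

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        for a, kw, result in cache_entries:
            if a == args and kw == kwargs:
                return result  # hit: reuse the cached result
        result = fn(*args, **kwargs)
        cache_entries.append((args, kwargs, result))
        if len(cache_entries) > cache_size:
            cache_entries.pop(0)  # evict the oldest entry
        return result

    return wrapper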

View File

@@ -8,7 +8,6 @@
 # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
 # ruff: noqa: E501
-from typing import Optional
 import torch
@@ -123,7 +122,7 @@ def recompute_w_u_fwd(
     beta: torch.Tensor,
     g_cumsum: torch.Tensor,
     A: torch.Tensor,
-    cu_seqlens: Optional[torch.LongTensor],
+    cu_seqlens: torch.LongTensor | None,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     B, T, Hg, K, V = *k.shape, v.shape[-1]
     H = v.shape[-2]