Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-12 17:51:31 +01:00
committed by GitHub
parent 9bb38130cb
commit 8fcaaf6a16
944 changed files with 9490 additions and 10121 deletions

View File

@@ -4,7 +4,7 @@ import contextlib
import os
from collections import namedtuple
from collections.abc import Callable
from typing import Any, Union
from typing import Any
import torch
@@ -138,7 +138,7 @@ def matmul_kernel_persistent(
def matmul_persistent(
a: torch.Tensor, b: torch.Tensor, bias: Union[torch.Tensor, None] = None
a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None = None
):
# Check constraints.
assert a.shape[1] == b.shape[0], "Incompatible dimensions"
@@ -375,7 +375,7 @@ def mean_dim(
input: torch.Tensor,
dim: int,
keepdim: bool = False,
dtype: Union[torch.dtype, None] = None,
dtype: torch.dtype | None = None,
) -> torch.Tensor:
"""
Triton implementation of torch.mean with single dimension reduction.
@@ -475,9 +475,7 @@ def _log_softmax_batch_invariant(input, dim, _half_to_float):
return log_softmax(input, dim=dim)
def mean_batch_invariant(
input, dim, keepdim=False, dtype: Union[torch.dtype, None] = None
):
def mean_batch_invariant(input, dim, keepdim=False, dtype: torch.dtype | None = None):
assert dtype is None or dtype == torch.float32, f"unsupported dtype: {dtype}"
result = input.to(torch.float32)