Update Optional[x] -> x | None and Union[x, y] to x | y (#26633)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -4,7 +4,7 @@ import contextlib
|
||||
import os
|
||||
from collections import namedtuple
|
||||
from collections.abc import Callable
|
||||
from typing import Any, Union
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
@@ -138,7 +138,7 @@ def matmul_kernel_persistent(
|
||||
|
||||
|
||||
def matmul_persistent(
|
||||
a: torch.Tensor, b: torch.Tensor, bias: Union[torch.Tensor, None] = None
|
||||
a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None = None
|
||||
):
|
||||
# Check constraints.
|
||||
assert a.shape[1] == b.shape[0], "Incompatible dimensions"
|
||||
@@ -375,7 +375,7 @@ def mean_dim(
|
||||
input: torch.Tensor,
|
||||
dim: int,
|
||||
keepdim: bool = False,
|
||||
dtype: Union[torch.dtype, None] = None,
|
||||
dtype: torch.dtype | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Triton implementation of torch.mean with single dimension reduction.
|
||||
@@ -475,9 +475,7 @@ def _log_softmax_batch_invariant(input, dim, _half_to_float):
|
||||
return log_softmax(input, dim=dim)
|
||||
|
||||
|
||||
def mean_batch_invariant(
|
||||
input, dim, keepdim=False, dtype: Union[torch.dtype, None] = None
|
||||
):
|
||||
def mean_batch_invariant(input, dim, keepdim=False, dtype: torch.dtype | None = None):
|
||||
assert dtype is None or dtype == torch.float32, f"unsupported dtype: {dtype}"
|
||||
|
||||
result = input.to(torch.float32)
|
||||
|
||||
Reference in New Issue
Block a user