[mypy] Fix wrong type annotations related to tuple (#25660)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-09-25 21:00:45 +08:00
committed by GitHub
parent 1e9a77e037
commit 2f17117606
9 changed files with 25 additions and 20 deletions

View File

@@ -60,7 +60,7 @@ TENSORS_SHAPES_FN = [
@torch.inference_mode()
def test_rotary_embedding(
is_neox_style: bool,
tensor_shape_fn: Callable[[int, int, int, int], tuple[int]],
tensor_shape_fn: Callable[[int, int, int, int], tuple[int, ...]],
batch_size: int,
seq_len: int,
num_heads: int,

View File

@@ -165,7 +165,7 @@ def onednn_gemm_test_helper(primitive_cache_size: int,
def test_onednn_int8_scaled_gemm(
n: int,
k: int,
m_list: tuple[int],
m_list: tuple[int, ...],
per_tensor_a_scale: bool,
per_tensor_b_scale: bool,
use_bias: bool,
@@ -196,7 +196,7 @@ def test_onednn_int8_scaled_gemm(
def test_onednn_gemm(
n: int,
k: int,
m_list: tuple[int],
m_list: tuple[int, ...],
use_bias: bool,
use_stride: bool,
dtype: torch.dtype,