Update deprecated Python 3.8 typing (#13971)

This commit is contained in:
Harry Mellor
2025-03-03 01:34:51 +00:00
committed by GitHub
parent bf33700ecd
commit cf069aa8aa
300 changed files with 2294 additions and 2347 deletions

View File

@@ -5,7 +5,8 @@ import copy
import itertools
import pickle as pkl
import time
from typing import Callable, Iterable, List, Optional, Tuple
from collections.abc import Iterable
from typing import Callable, Optional
import torch
import torch.utils.benchmark as TBenchmark
@@ -49,7 +50,7 @@ def bench_int8(
n: int,
label: str,
sub_label: str,
bench_kernels: Optional[List[str]] = None) -> Iterable[TMeasurement]:
bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]:
"""Benchmark INT8-based kernels."""
assert dtype == torch.int8
a, b = make_rand_tensors(torch.int8, m, n, k)
@@ -101,7 +102,7 @@ def bench_fp8(
n: int,
label: str,
sub_label: str,
bench_kernels: Optional[List[str]] = None) -> Iterable[TMeasurement]:
bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]:
"""Benchmark FP8-based kernels."""
assert dtype == torch.float8_e4m3fn
a, b = make_rand_tensors(torch.float8_e4m3fn, m, n, k)
@@ -180,7 +181,7 @@ def bench(dtype: torch.dtype,
n: int,
label: str,
sub_label: str,
bench_kernels: Optional[List[str]] = None) -> Iterable[TMeasurement]:
bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]:
if dtype == torch.int8:
return bench_int8(dtype, m, k, n, label, sub_label, bench_kernels)
if dtype == torch.float8_e4m3fn:
@@ -195,8 +196,8 @@ def print_timers(timers: Iterable[TMeasurement]):
def run(dtype: torch.dtype,
MKNs: Iterable[Tuple[int, int, int]],
bench_kernels: Optional[List[str]] = None) -> Iterable[TMeasurement]:
MKNs: Iterable[tuple[int, int, int]],
bench_kernels: Optional[list[str]] = None) -> Iterable[TMeasurement]:
results = []
for m, k, n in MKNs:
timers = bench(dtype,
@@ -212,7 +213,7 @@ def run(dtype: torch.dtype,
def make_output(data: Iterable[TMeasurement],
MKNs: Iterable[Tuple[int, int, int]],
MKNs: Iterable[tuple[int, int, int]],
base_description: str,
timestamp=None):
print(f"== All Results {base_description} ====")
@@ -248,7 +249,7 @@ def run_model_bench(args):
for i, model in enumerate(args.models):
print(f"[{i}] {model}")
def model_shapes(model_name: str, tp_size: int) -> List[Tuple[int, int]]:
def model_shapes(model_name: str, tp_size: int) -> list[tuple[int, int]]:
KNs = []
for KN, tp_split_dim in copy.deepcopy(WEIGHT_SHAPES[model_name]):
KN[tp_split_dim] = KN[tp_split_dim] // tp_size