Update deprecated Python 3.8 typing (#13971)
This commit is contained in:
@@ -1,7 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
import torch.utils.benchmark as benchmark
|
||||
from benchmark_shapes import WEIGHT_SHAPES
|
||||
@@ -31,7 +29,7 @@ ACT_ORDER_OPTS = [False, True]
|
||||
K_FULL_OPTS = [False, True]
|
||||
|
||||
|
||||
def bench_run(results: List[benchmark.Measurement], model: str,
|
||||
def bench_run(results: list[benchmark.Measurement], model: str,
|
||||
act_order: bool, is_k_full: bool, quant_type: ScalarType,
|
||||
group_size: int, size_m: int, size_k: int, size_n: int):
|
||||
label = "Quant Matmul"
|
||||
@@ -221,7 +219,7 @@ def main(args):
|
||||
for i, model in enumerate(args.models):
|
||||
print(f"[{i}] {model}")
|
||||
|
||||
results: List[benchmark.Measurement] = []
|
||||
results: list[benchmark.Measurement] = []
|
||||
|
||||
for model in args.models:
|
||||
for layer in WEIGHT_SHAPES[model]:
|
||||
|
||||
Reference in New Issue
Block a user