Update deprecated Python 3.8 typing (#13971)
This commit is contained in:
@@ -4,7 +4,7 @@ import argparse
|
||||
import time
|
||||
from datetime import datetime
|
||||
from itertools import product
|
||||
from typing import Any, Dict, List, Tuple, TypedDict
|
||||
from typing import Any, TypedDict
|
||||
|
||||
import ray
|
||||
import torch
|
||||
@@ -132,7 +132,7 @@ def benchmark_config(
|
||||
start_event = torch.cuda.Event(enable_timing=True)
|
||||
end_event = torch.cuda.Event(enable_timing=True)
|
||||
|
||||
latencies: List[float] = []
|
||||
latencies: list[float] = []
|
||||
for i in range(num_iters):
|
||||
prepare(i)
|
||||
torch.cuda.synchronize()
|
||||
@@ -175,8 +175,8 @@ def get_rocm_tuning_space(use_fp16):
|
||||
return param_ranges
|
||||
|
||||
|
||||
def get_configs_compute_bound(use_fp16) -> List[Dict[str, int]]:
|
||||
configs: List[BenchmarkConfig] = []
|
||||
def get_configs_compute_bound(use_fp16) -> list[dict[str, int]]:
|
||||
configs: list[BenchmarkConfig] = []
|
||||
|
||||
if current_platform.is_rocm():
|
||||
param_ranges = get_rocm_tuning_space(use_fp16)
|
||||
@@ -335,7 +335,7 @@ class BenchmarkWorker:
|
||||
dtype: torch.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
) -> Tuple[Dict[str, int], float]:
|
||||
) -> tuple[dict[str, int], float]:
|
||||
current_platform.seed_everything(self.seed)
|
||||
dtype_str = get_config_dtype_str(dtype,
|
||||
use_int8_w8a16=use_int8_w8a16,
|
||||
@@ -371,8 +371,8 @@ class BenchmarkWorker:
|
||||
dtype: torch.dtype,
|
||||
use_fp8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
search_space: List[Dict[str, int]],
|
||||
) -> Dict[str, int]:
|
||||
search_space: list[dict[str, int]],
|
||||
) -> dict[str, int]:
|
||||
best_config = None
|
||||
best_time = float("inf")
|
||||
if current_platform.is_rocm():
|
||||
@@ -434,7 +434,7 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
|
||||
}
|
||||
|
||||
|
||||
def save_configs(configs: Dict[int, BenchmarkConfig], num_experts: int,
|
||||
def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int,
|
||||
shard_intermediate_size: int, hidden_size: int, topk: int,
|
||||
dtype: torch.dtype, use_fp8_w8a8: bool,
|
||||
use_int8_w8a16: bool) -> None:
|
||||
@@ -498,7 +498,7 @@ def main(args: argparse.Namespace):
|
||||
num_gpus = int(ray.available_resources()["GPU"])
|
||||
workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]
|
||||
|
||||
def _distribute(method: str, inputs: List[Any]) -> List[Any]:
|
||||
def _distribute(method: str, inputs: list[Any]) -> list[Any]:
|
||||
outputs = []
|
||||
worker_idx = 0
|
||||
for input_args in inputs:
|
||||
|
||||
Reference in New Issue
Block a user