Update deprecated Python 3.8 typing (#13971)

This commit is contained in:
Harry Mellor
2025-03-03 01:34:51 +00:00
committed by GitHub
parent bf33700ecd
commit cf069aa8aa
300 changed files with 2294 additions and 2347 deletions

View File

@@ -9,7 +9,7 @@ from dataclasses import dataclass
from enum import Enum, auto
from itertools import product
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Optional
import torch
import torch.utils.benchmark as TBenchmark
@@ -61,15 +61,15 @@ def make_rand_lora_weight_tensor(k: int,
def make_rand_tensors(
a_shape: Tuple[int],
b_shape: Tuple[int],
c_shape: Tuple[int],
a_shape: tuple[int],
b_shape: tuple[int],
c_shape: tuple[int],
a_dtype: torch.dtype,
b_dtype: torch.dtype,
c_dtype: torch.dtype,
num_slices: int,
device: str = "cuda",
) -> Tuple[torch.Tensor, List[torch.Tensor], torch.Tensor]:
) -> tuple[torch.Tensor, list[torch.Tensor], torch.Tensor]:
"""
Make LoRA input/output matrices.
"""
@@ -135,7 +135,7 @@ def make_token_lora_mapping(num_tokens: int, num_prompts: int,
def ref_group_gemm(ref_out: torch.Tensor, input: torch.Tensor,
lora_weights: List[torch.Tensor],
lora_weights: list[torch.Tensor],
seq_lens_cpu: torch.Tensor,
prompt_lora_mapping_cpu: torch.Tensor, scaling: float,
add_inputs: Optional[bool]):
@@ -204,7 +204,7 @@ class OpType(Enum):
def is_expand_slice_fn(self) -> bool:
return self in [OpType.BGMV_EXPAND_SLICE]
def num_slices(self) -> List[int]:
def num_slices(self) -> list[int]:
if self in [OpType.SGMV_EXPAND, OpType.SGMV_SHRINK]:
# SGMV kernels supports slices
return [1, 2, 3]
@@ -215,7 +215,7 @@ class OpType(Enum):
raise ValueError(f"Unrecognized OpType {self}")
def mkn(self, batch_size: int, seq_length: int, hidden_size: int,
lora_rank: int) -> Tuple[int, int, int]:
lora_rank: int) -> tuple[int, int, int]:
num_tokens = batch_size * seq_length
if self.is_shrink_fn():
m = num_tokens
@@ -230,7 +230,7 @@ class OpType(Enum):
def matmul_dtypes(
self, op_dtype: torch.dtype
) -> Tuple[torch.dtype, torch.dtype, torch.dtype]:
) -> tuple[torch.dtype, torch.dtype, torch.dtype]:
"""
return a type, b type and c type for A x B = C
"""
@@ -243,7 +243,7 @@ class OpType(Enum):
def matmul_shapes(
self, batch_size: int, seq_length: int, hidden_size: int,
lora_rank: int, num_loras: int,
num_slices: int) -> Tuple[Tuple[int], Tuple[int], Tuple[int]]:
num_slices: int) -> tuple[tuple[int], tuple[int], tuple[int]]:
"""
Given num_slices, return the shapes of the A, B, and C matrices
in A x B = C, for the op_type
@@ -268,7 +268,7 @@ class OpType(Enum):
def bench_fn(self) -> Callable:
def emulate_bgmv_expand_slice(kwargs_list: List[Dict[str, Any]]):
def emulate_bgmv_expand_slice(kwargs_list: list[dict[str, Any]]):
for x in kwargs_list:
bgmv_expand_slice(**x)
@@ -285,7 +285,7 @@ class OpType(Enum):
raise ValueError(f"Unrecognized optype {self}")
def run_ref_group_gemm(self, output: torch.Tensor, input: torch.Tensor,
lora_weights: List[torch.Tensor],
lora_weights: list[torch.Tensor],
**kwargs) -> Callable:
"""Each benchmark operation expected the input, lora_weights and outputs
in a slightly different format. Refer to self.matmul_shapes().
@@ -384,7 +384,7 @@ class BenchmarkTensors:
"""
# matmul tensors
input: torch.Tensor
lora_weights_lst: List[torch.Tensor]
lora_weights_lst: list[torch.Tensor]
output: torch.Tensor
# metadata tensors
seq_lens: torch.Tensor
@@ -469,7 +469,7 @@ class BenchmarkTensors:
for i in range(len(self.lora_weights_lst)):
self.lora_weights_lst[i] = to_device(self.lora_weights_lst[i])
def metadata(self) -> Tuple[int, int, int]:
def metadata(self) -> tuple[int, int, int]:
"""
Return num_seqs, num_tokens and max_seq_len
"""
@@ -505,7 +505,7 @@ class BenchmarkTensors:
self.seq_lens = seq_lens.to(dtype=self.seq_lens.dtype)
self.seq_start_loc = seq_start_loc.to(dtype=self.seq_start_loc.dtype)
def as_sgmv_shrink_kwargs(self) -> Dict[str, Any]:
def as_sgmv_shrink_kwargs(self) -> dict[str, Any]:
self.convert_to_sgmv_benchmark_tensors()
self.sanity_check()
self.to_device(self.input.device)
@@ -540,7 +540,7 @@ class BenchmarkTensors:
'scaling': 1.0,
}
def as_sgmv_expand_kwargs(self, add_inputs: bool) -> Dict[str, Any]:
def as_sgmv_expand_kwargs(self, add_inputs: bool) -> dict[str, Any]:
self.convert_to_sgmv_benchmark_tensors()
self.sanity_check()
@@ -578,7 +578,7 @@ class BenchmarkTensors:
'add_inputs': add_inputs,
}
def as_bgmv_shrink_kwargs(self) -> Dict[str, Any]:
def as_bgmv_shrink_kwargs(self) -> dict[str, Any]:
assert len(self.lora_weights_lst) == 1
self.to_device(self.input.device)
@@ -634,7 +634,7 @@ class BenchmarkTensors:
'add_inputs': add_inputs
}
def as_bgmv_expand_slice_kwargs(self, add_inputs: bool) -> Dict[str, Any]:
def as_bgmv_expand_slice_kwargs(self, add_inputs: bool) -> dict[str, Any]:
_, num_tokens, _, num_slices = self.metadata()
# Sanity check shapes
@@ -670,7 +670,7 @@ class BenchmarkTensors:
def bench_fn_kwargs(self,
op_type: OpType,
add_inputs: Optional[bool] = None) -> Dict[str, Any]:
add_inputs: Optional[bool] = None) -> dict[str, Any]:
if op_type.is_shrink_fn():
assert add_inputs is None
else:
@@ -734,7 +734,7 @@ def bench_optype(ctx: BenchmarkContext,
assert expand_fn_add_inputs is not None
# BenchmarkContext -> BenchmarkTensors
bench_tensors : List[BenchmarkTensors] = \
bench_tensors : list[BenchmarkTensors] = \
[BenchmarkTensors.make(ctx, op_type) for _ in range(arg_pool_size)]
for bt in bench_tensors:
bt.sanity_check()
@@ -746,7 +746,7 @@ def bench_optype(ctx: BenchmarkContext,
for bt in bench_tensors
])
# BenchmarkTensors -> Dict (kwargs)
# BenchmarkTensors -> dict (kwargs)
kwargs_list = [
bt.bench_fn_kwargs(op_type, add_inputs=expand_fn_add_inputs)
for bt in bench_tensors
@@ -841,7 +841,7 @@ def use_cuda_graph_recommendation() -> str:
"""
def print_timers(timers: List[TMeasurement],
def print_timers(timers: list[TMeasurement],
args: Optional[argparse.Namespace] = None):
compare = TBenchmark.Compare(timers)
compare.print()
@@ -861,7 +861,7 @@ def print_timers(timers: List[TMeasurement],
"small num_loras the goal should be to match the torch.mm numbers.")
def run(args: argparse.Namespace, bench_ctxs: List[BenchmarkContext]):
def run(args: argparse.Namespace, bench_ctxs: list[BenchmarkContext]):
if args.cuda_graph_nops is not None:
assert args.cuda_graph_nops > 0
@@ -873,7 +873,7 @@ def run(args: argparse.Namespace, bench_ctxs: List[BenchmarkContext]):
timers = []
for bench_ctx in bench_ctxs:
for seq_len in args.seq_lengths:
bench_ops: List[OpType] = []
bench_ops: list[OpType] = []
if seq_len == 1:
# bench all decode ops
bench_ops = [op for op in args.op_types if op.is_decode_op()]
@@ -921,10 +921,10 @@ def run(args: argparse.Namespace, bench_ctxs: List[BenchmarkContext]):
pickle.dump(timers, f)
def as_benchmark_contexts(hidden_sizes: List[int], lora_ranks: List[int],
args: argparse.Namespace) -> List[BenchmarkContext]:
def as_benchmark_contexts(hidden_sizes: list[int], lora_ranks: list[int],
args: argparse.Namespace) -> list[BenchmarkContext]:
ctxs: List[BenchmarkContext] = []
ctxs: list[BenchmarkContext] = []
for batch_size, hidden_size, lora_rank, num_loras, sort_by_lora_id in product( # noqa
args.batch_sizes, list(hidden_sizes), lora_ranks, args.num_loras,
args.sort_by_lora_id):
@@ -954,7 +954,7 @@ def run_list_bench(args: argparse.Namespace):
f" LoRA Ranks {args.lora_ranks}")
# Get all benchmarking contexts
bench_contexts: List[BenchmarkContext] = as_benchmark_contexts(
bench_contexts: list[BenchmarkContext] = as_benchmark_contexts(
hidden_sizes=args.hidden_sizes, lora_ranks=args.lora_ranks, args=args)
run(args, bench_contexts)
@@ -975,7 +975,7 @@ def run_range_bench(args: argparse.Namespace):
f" LoRA Ranks {lora_ranks}")
# Get all benchmarking contexts
bench_contexts: List[BenchmarkContext] = as_benchmark_contexts(
bench_contexts: list[BenchmarkContext] = as_benchmark_contexts(
hidden_sizes=hidden_sizes, lora_ranks=lora_ranks, args=args)
run(args, bench_contexts)
@@ -1002,7 +1002,7 @@ def run_model_bench(args: argparse.Namespace):
f" LoRA Ranks {args.lora_ranks}")
# Get all benchmarking contexts
bench_contexts: List[BenchmarkContext] = as_benchmark_contexts(
bench_contexts: list[BenchmarkContext] = as_benchmark_contexts(
hidden_sizes=hidden_sizes, lora_ranks=args.lora_ranks, args=args)
run(args, bench_contexts)