Update deprecated Python 3.8 typing (#13971)

This commit is contained in:
Harry Mellor
2025-03-03 01:34:51 +00:00
committed by GitHub
parent bf33700ecd
commit cf069aa8aa
300 changed files with 2294 additions and 2347 deletions

View File

@@ -5,8 +5,9 @@ import json
import os
import sys
from argparse import RawTextHelpFormatter
from collections.abc import Generator
from dataclasses import asdict, dataclass
from typing import Any, Dict, Generator, List, Optional, TypeAlias
from typing import Any, Optional, TypeAlias
import torch
import tqdm
@@ -42,8 +43,8 @@ def get_dtype(dtype: str):
return dtype
OutputLen_NumReqs_Map: TypeAlias = Dict[int, int]
def compute_request_output_lengths(batch_size: int, step_requests: List[int]) \
OutputLen_NumReqs_Map: TypeAlias = dict[int, int]
def compute_request_output_lengths(batch_size: int, step_requests: list[int]) \
-> OutputLen_NumReqs_Map:
"""
Given the number of requests, batch_size, and the number of requests
@@ -63,7 +64,7 @@ def compute_request_output_lengths(batch_size: int, step_requests: List[int]) \
Args:
batch_size (int): Number of requests submitted for profile. This is
args.batch_size.
step_requests (List[int]): step_requests[i] is the number of requests
step_requests (list[int]): step_requests[i] is the number of requests
that the ith engine step should process.
Returns:
@@ -114,7 +115,7 @@ def compute_request_output_lengths(batch_size: int, step_requests: List[int]) \
return ol_nr
def determine_requests_per_step(context: ProfileContext) -> List[int]:
def determine_requests_per_step(context: ProfileContext) -> list[int]:
"""
Determine number of requests each engine step should process.
If context.num_steps is set, then all engine steps process the
@@ -130,7 +131,7 @@ def determine_requests_per_step(context: ProfileContext) -> List[int]:
context: ProfileContext object.
Returns:
List[int]: Number of requests to process for all engine-steps.
list[int]: Number of requests to process for all engine-steps.
output[i], contains the number of requests that the ith step
should process.
"""
@@ -170,7 +171,7 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
for key, value in asdict(context).items():
print(f" {key} = {value}")
requests_per_step: List[int] = determine_requests_per_step(context)
requests_per_step: list[int] = determine_requests_per_step(context)
ol_nr: OutputLen_NumReqs_Map = compute_request_output_lengths(
context.batch_size, requests_per_step)