Update deprecated type hinting in vllm/profiler (#18057)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-05-13 12:34:34 +01:00
committed by GitHub
parent 6223dd8114
commit ff334ca1cd
3 changed files with 23 additions and 24 deletions

View File

@@ -3,7 +3,7 @@
import copy
from collections import defaultdict
from dataclasses import asdict, dataclass, field
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeAlias, Union
from typing import Any, Callable, Optional, TypeAlias, Union
import pandas as pd
from torch._C._autograd import DeviceType, _KinetoEvent, _ProfilerResult
@@ -20,7 +20,7 @@ from vllm.profiler.utils import (TablePrinter, event_has_module,
class _ModuleTreeNode:
event: _ProfilerEvent
parent: Optional['_ModuleTreeNode'] = None
children: List['_ModuleTreeNode'] = field(default_factory=list)
children: list['_ModuleTreeNode'] = field(default_factory=list)
trace: str = ""
@property
@@ -60,19 +60,19 @@ StatsEntry: TypeAlias = Union[ModelStatsEntry, SummaryStatsEntry]
@dataclass
class _StatsTreeNode:
entry: StatsEntry
children: List[StatsEntry]
children: list[StatsEntry]
parent: Optional[StatsEntry]
@dataclass
class LayerwiseProfileResults(profile):
_kineto_results: _ProfilerResult
_kineto_event_correlation_map: Dict[int,
List[_KinetoEvent]] = field(init=False)
_event_correlation_map: Dict[int, List[FunctionEvent]] = field(init=False)
_module_tree: List[_ModuleTreeNode] = field(init=False)
_model_stats_tree: List[_StatsTreeNode] = field(init=False)
_summary_stats_tree: List[_StatsTreeNode] = field(init=False)
_kineto_event_correlation_map: dict[int,
list[_KinetoEvent]] = field(init=False)
_event_correlation_map: dict[int, list[FunctionEvent]] = field(init=False)
_module_tree: list[_ModuleTreeNode] = field(init=False)
_model_stats_tree: list[_StatsTreeNode] = field(init=False)
_summary_stats_tree: list[_StatsTreeNode] = field(init=False)
# profile metadata
num_running_seqs: Optional[int] = None
@@ -82,7 +82,7 @@ class LayerwiseProfileResults(profile):
self._build_module_tree()
self._build_stats_trees()
def print_model_table(self, column_widths: Dict[str, int] = None):
def print_model_table(self, column_widths: dict[str, int] = None):
_column_widths = dict(name=60,
cpu_time_us=12,
cuda_time_us=12,
@@ -100,7 +100,7 @@ class LayerwiseProfileResults(profile):
filtered_model_table,
indent_style=lambda indent: "|" + "-" * indent + " "))
def print_summary_table(self, column_widths: Dict[str, int] = None):
def print_summary_table(self, column_widths: dict[str, int] = None):
_column_widths = dict(name=80,
cuda_time_us=12,
pct_cuda_time=12,
@@ -142,7 +142,7 @@ class LayerwiseProfileResults(profile):
}
@staticmethod
def _indent_row_names_based_on_depth(depths_rows: List[Tuple[int,
def _indent_row_names_based_on_depth(depths_rows: list[tuple[int,
StatsEntry]],
indent_style: Union[Callable[[int],
str],
@@ -229,7 +229,7 @@ class LayerwiseProfileResults(profile):
[self._cumulative_cuda_time(root) for root in self._module_tree])
def _build_stats_trees(self):
summary_dict: Dict[str, _StatsTreeNode] = {}
summary_dict: dict[str, _StatsTreeNode] = {}
total_cuda_time = self._total_cuda_time()
def pct_cuda_time(cuda_time_us):
@@ -238,7 +238,7 @@ class LayerwiseProfileResults(profile):
def build_summary_stats_tree_df(
node: _ModuleTreeNode,
parent: Optional[_StatsTreeNode] = None,
summary_trace: Tuple[str] = ()):
summary_trace: tuple[str] = ()):
if event_has_module(node.event):
name = event_module_repr(node.event)
@@ -313,8 +313,8 @@ class LayerwiseProfileResults(profile):
self._model_stats_tree.append(build_model_stats_tree_df(root))
def _flatten_stats_tree(
self, tree: List[_StatsTreeNode]) -> List[Tuple[int, StatsEntry]]:
entries: List[Tuple[int, StatsEntry]] = []
self, tree: list[_StatsTreeNode]) -> list[tuple[int, StatsEntry]]:
entries: list[tuple[int, StatsEntry]] = []
def df_traversal(node: _StatsTreeNode, depth=0):
entries.append((depth, node.entry))
@@ -327,10 +327,10 @@ class LayerwiseProfileResults(profile):
return entries
def _convert_stats_tree_to_dict(self,
tree: List[_StatsTreeNode]) -> List[Dict]:
root_dicts: List[Dict] = []
tree: list[_StatsTreeNode]) -> list[dict]:
root_dicts: list[dict] = []
def df_traversal(node: _StatsTreeNode, curr_json_list: List[Dict]):
def df_traversal(node: _StatsTreeNode, curr_json_list: list[dict]):
curr_json_list.append({
"entry": asdict(node.entry),
"children": []