diff --git a/tools/pre_commit/mypy.py b/tools/pre_commit/mypy.py index 4bc0b3ad4..27312ac59 100755 --- a/tools/pre_commit/mypy.py +++ b/tools/pre_commit/mypy.py @@ -43,7 +43,6 @@ EXCLUDE = [ "vllm/benchmarks", "vllm/config", "vllm/device_allocator", - "vllm/profiler", "vllm/reasoning", "vllm/tool_parser", ] diff --git a/vllm/profiler/layerwise_profile.py b/vllm/profiler/layerwise_profile.py index 6b4348b96..a36e4611f 100644 --- a/vllm/profiler/layerwise_profile.py +++ b/vllm/profiler/layerwise_profile.py @@ -5,7 +5,7 @@ import copy from collections import defaultdict from collections.abc import Callable from dataclasses import asdict, dataclass, field -from typing import Any, TypeAlias +from typing import Any, Generic, TypeAlias, TypeVar from torch._C._autograd import DeviceType, _KinetoEvent, _ProfilerResult from torch._C._profiler import _EventType, _ExperimentalConfig, _ProfilerEvent @@ -69,13 +69,14 @@ class ModelStatsEntry: StatsEntry: TypeAlias = ModelStatsEntry | SummaryStatsEntry +StatsEntryT = TypeVar("StatsEntryT", bound=StatsEntry) @dataclass -class _StatsTreeNode: - entry: StatsEntry - children: list[StatsEntry] - parent: StatsEntry | None +class _StatsTreeNode(Generic[StatsEntryT]): + entry: StatsEntryT + children: list["_StatsTreeNode[StatsEntryT]"] = field(default_factory=list) + parent: "_StatsTreeNode[StatsEntryT] | None" = None @dataclass @@ -84,8 +85,8 @@ class LayerwiseProfileResults(profile): _kineto_event_correlation_map: dict[int, list[_KinetoEvent]] = field(init=False) _event_correlation_map: dict[int, list[FunctionEvent]] = field(init=False) _module_tree: list[_ModuleTreeNode] = field(init=False) - _model_stats_tree: list[_StatsTreeNode] = field(init=False) - _summary_stats_tree: list[_StatsTreeNode] = field(init=False) + _model_stats_tree: list[_StatsTreeNode[ModelStatsEntry]] = field(init=False) + _summary_stats_tree: list[_StatsTreeNode[SummaryStatsEntry]] = field(init=False) # profile metadata num_running_seqs: int | None = None @@ -95,7 +96,7 @@ class LayerwiseProfileResults(profile): self._build_module_tree() self._build_stats_trees() - def print_model_table(self, column_widths: dict[str, int] = None): + def print_model_table(self, column_widths: dict[str, int] | None = None): _column_widths = dict( name=60, cpu_time_us=12, cuda_time_us=12, pct_cuda_time=12, trace=60 ) @@ -113,7 +114,7 @@ class LayerwiseProfileResults(profile): ) ) - def print_summary_table(self, column_widths: dict[str, int] = None): + def print_summary_table(self, column_widths: dict[str, int] | None = None): _column_widths = dict( name=80, cuda_time_us=12, pct_cuda_time=12, invocations=15 ) @@ -155,14 +156,14 @@ class LayerwiseProfileResults(profile): @staticmethod def _indent_row_names_based_on_depth( - depths_rows: list[tuple[int, StatsEntry]], + depths_rows: list[tuple[int, StatsEntryT]], indent_style: Callable[[int], str] | str = " ", ): - indented_rows = [] + indented_rows: list[StatsEntryT] = [] for depth, row in depths_rows: if row.cuda_time_us == 0: continue - indented_row = copy.deepcopy(row) + indented_row: StatsEntryT = copy.deepcopy(row) indented_row.name = indent_string(indented_row.name, depth, indent_style) indented_rows.append(indented_row) return indented_rows @@ -240,7 +241,7 @@ class LayerwiseProfileResults(profile): return sum([self._cumulative_cuda_time(root) for root in self._module_tree]) def _build_stats_trees(self): - summary_dict: dict[str, _StatsTreeNode] = {} + summary_dict: dict[tuple[str, ...], _StatsTreeNode[SummaryStatsEntry]] = {} total_cuda_time = self._total_cuda_time() def pct_cuda_time(cuda_time_us): @@ -248,9 +249,9 @@ class LayerwiseProfileResults(profile): def build_summary_stats_tree_df( node: _ModuleTreeNode, - parent: _StatsTreeNode | None = None, - summary_trace: tuple[str] = (), - ): + parent: _StatsTreeNode[SummaryStatsEntry] | None = None, + summary_trace: tuple[str, ...] = (), + ) -> _StatsTreeNode[SummaryStatsEntry] | None: if event_has_module(node.event): name = event_module_repr(node.event) cuda_time_us = self._cumulative_cuda_time(node) @@ -274,7 +275,6 @@ class LayerwiseProfileResults(profile): pct_cuda_time=pct_cuda_time(cuda_time_us), invocations=1, ), - children=[], parent=parent, ) if parent: @@ -290,11 +290,14 @@ class LayerwiseProfileResults(profile): self._summary_stats_tree = [] for root in self._module_tree: - self._summary_stats_tree.append(build_summary_stats_tree_df(root)) + summary_node = build_summary_stats_tree_df(root) + if summary_node is not None: + self._summary_stats_tree.append(summary_node) def build_model_stats_tree_df( - node: _ModuleTreeNode, parent: _StatsTreeNode | None = None - ): + node: _ModuleTreeNode, + parent: _StatsTreeNode[ModelStatsEntry] | None = None, + ) -> _StatsTreeNode[ModelStatsEntry] | None: if event_has_module( node.event, ): @@ -319,7 +322,6 @@ class LayerwiseProfileResults(profile): trace=trace, ), parent=parent, - children=[], ) if parent: parent.children.append(new_node) @@ -331,14 +333,16 @@ class LayerwiseProfileResults(profile): self._model_stats_tree = [] for root in self._module_tree: - self._model_stats_tree.append(build_model_stats_tree_df(root)) + model_node = build_model_stats_tree_df(root) + if model_node is not None: + self._model_stats_tree.append(model_node) def _flatten_stats_tree( - self, tree: list[_StatsTreeNode] - ) -> list[tuple[int, StatsEntry]]: - entries: list[tuple[int, StatsEntry]] = [] + self, tree: list[_StatsTreeNode[StatsEntryT]] + ) -> list[tuple[int, StatsEntryT]]: + entries: list[tuple[int, StatsEntryT]] = [] - def df_traversal(node: _StatsTreeNode, depth=0): + def df_traversal(node: _StatsTreeNode[StatsEntryT], depth: int = 0): entries.append((depth, node.entry)) for child in node.children: df_traversal(child, depth=depth + 1) @@ -348,10 +352,14 @@ class LayerwiseProfileResults(profile): return entries - def _convert_stats_tree_to_dict(self, tree: list[_StatsTreeNode]) -> list[dict]: - root_dicts: list[dict] = [] + def _convert_stats_tree_to_dict( + self, tree: list[_StatsTreeNode[StatsEntryT]] + ) -> list[dict[str, Any]]: + root_dicts: list[dict[str, Any]] = [] - def df_traversal(node: _StatsTreeNode, curr_json_list: list[dict]): + def df_traversal( + node: _StatsTreeNode[StatsEntryT], curr_json_list: list[dict[str, Any]] + ): curr_json_list.append({"entry": asdict(node.entry), "children": []}) for child in node.children: df_traversal(child, curr_json_list[-1]["children"])