[Misc] Fix mypy errors in vllm/profiler and remove from exclude list (#34959)

Signed-off-by: Taneem Ibrahim <taneem.ibrahim@gmail.com>
This commit is contained in:
Taneem Ibrahim
2026-02-20 21:56:33 -06:00
committed by GitHub
parent ded333fb9b
commit d38cd3dde5
2 changed files with 37 additions and 30 deletions

View File

@@ -43,7 +43,6 @@ EXCLUDE = [
"vllm/benchmarks",
"vllm/config",
"vllm/device_allocator",
"vllm/profiler",
"vllm/reasoning",
"vllm/tool_parser",
]

View File

@@ -5,7 +5,7 @@ import copy
from collections import defaultdict
from collections.abc import Callable
from dataclasses import asdict, dataclass, field
from typing import Any, TypeAlias
from typing import Any, Generic, TypeAlias, TypeVar
from torch._C._autograd import DeviceType, _KinetoEvent, _ProfilerResult
from torch._C._profiler import _EventType, _ExperimentalConfig, _ProfilerEvent
@@ -69,13 +69,14 @@ class ModelStatsEntry:
StatsEntry: TypeAlias = ModelStatsEntry | SummaryStatsEntry
StatsEntryT = TypeVar("StatsEntryT", bound=StatsEntry)
@dataclass
class _StatsTreeNode:
entry: StatsEntry
children: list[StatsEntry]
parent: StatsEntry | None
class _StatsTreeNode(Generic[StatsEntryT]):
entry: StatsEntryT
children: list["_StatsTreeNode[StatsEntryT]"] = field(default_factory=list)
parent: "_StatsTreeNode[StatsEntryT] | None" = None
@dataclass
@@ -84,8 +85,8 @@ class LayerwiseProfileResults(profile):
_kineto_event_correlation_map: dict[int, list[_KinetoEvent]] = field(init=False)
_event_correlation_map: dict[int, list[FunctionEvent]] = field(init=False)
_module_tree: list[_ModuleTreeNode] = field(init=False)
_model_stats_tree: list[_StatsTreeNode] = field(init=False)
_summary_stats_tree: list[_StatsTreeNode] = field(init=False)
_model_stats_tree: list[_StatsTreeNode[ModelStatsEntry]] = field(init=False)
_summary_stats_tree: list[_StatsTreeNode[SummaryStatsEntry]] = field(init=False)
# profile metadata
num_running_seqs: int | None = None
@@ -95,7 +96,7 @@ class LayerwiseProfileResults(profile):
self._build_module_tree()
self._build_stats_trees()
def print_model_table(self, column_widths: dict[str, int] = None):
def print_model_table(self, column_widths: dict[str, int] | None = None):
_column_widths = dict(
name=60, cpu_time_us=12, cuda_time_us=12, pct_cuda_time=12, trace=60
)
@@ -113,7 +114,7 @@ class LayerwiseProfileResults(profile):
)
)
def print_summary_table(self, column_widths: dict[str, int] = None):
def print_summary_table(self, column_widths: dict[str, int] | None = None):
_column_widths = dict(
name=80, cuda_time_us=12, pct_cuda_time=12, invocations=15
)
@@ -155,14 +156,14 @@ class LayerwiseProfileResults(profile):
@staticmethod
def _indent_row_names_based_on_depth(
depths_rows: list[tuple[int, StatsEntry]],
depths_rows: list[tuple[int, StatsEntryT]],
indent_style: Callable[[int], str] | str = " ",
):
indented_rows = []
indented_rows: list[StatsEntryT] = []
for depth, row in depths_rows:
if row.cuda_time_us == 0:
continue
indented_row = copy.deepcopy(row)
indented_row: StatsEntryT = copy.deepcopy(row)
indented_row.name = indent_string(indented_row.name, depth, indent_style)
indented_rows.append(indented_row)
return indented_rows
@@ -240,7 +241,7 @@ class LayerwiseProfileResults(profile):
return sum([self._cumulative_cuda_time(root) for root in self._module_tree])
def _build_stats_trees(self):
summary_dict: dict[str, _StatsTreeNode] = {}
summary_dict: dict[tuple[str, ...], _StatsTreeNode[SummaryStatsEntry]] = {}
total_cuda_time = self._total_cuda_time()
def pct_cuda_time(cuda_time_us):
@@ -248,9 +249,9 @@ class LayerwiseProfileResults(profile):
def build_summary_stats_tree_df(
node: _ModuleTreeNode,
parent: _StatsTreeNode | None = None,
summary_trace: tuple[str] = (),
):
parent: _StatsTreeNode[SummaryStatsEntry] | None = None,
summary_trace: tuple[str, ...] = (),
) -> _StatsTreeNode[SummaryStatsEntry] | None:
if event_has_module(node.event):
name = event_module_repr(node.event)
cuda_time_us = self._cumulative_cuda_time(node)
@@ -274,7 +275,6 @@ class LayerwiseProfileResults(profile):
pct_cuda_time=pct_cuda_time(cuda_time_us),
invocations=1,
),
children=[],
parent=parent,
)
if parent:
@@ -290,11 +290,14 @@ class LayerwiseProfileResults(profile):
self._summary_stats_tree = []
for root in self._module_tree:
self._summary_stats_tree.append(build_summary_stats_tree_df(root))
summary_node = build_summary_stats_tree_df(root)
if summary_node is not None:
self._summary_stats_tree.append(summary_node)
def build_model_stats_tree_df(
node: _ModuleTreeNode, parent: _StatsTreeNode | None = None
):
node: _ModuleTreeNode,
parent: _StatsTreeNode[ModelStatsEntry] | None = None,
) -> _StatsTreeNode[ModelStatsEntry] | None:
if event_has_module(
node.event,
):
@@ -319,7 +322,6 @@ class LayerwiseProfileResults(profile):
trace=trace,
),
parent=parent,
children=[],
)
if parent:
parent.children.append(new_node)
@@ -331,14 +333,16 @@ class LayerwiseProfileResults(profile):
self._model_stats_tree = []
for root in self._module_tree:
self._model_stats_tree.append(build_model_stats_tree_df(root))
model_node = build_model_stats_tree_df(root)
if model_node is not None:
self._model_stats_tree.append(model_node)
def _flatten_stats_tree(
self, tree: list[_StatsTreeNode]
) -> list[tuple[int, StatsEntry]]:
entries: list[tuple[int, StatsEntry]] = []
self, tree: list[_StatsTreeNode[StatsEntryT]]
) -> list[tuple[int, StatsEntryT]]:
entries: list[tuple[int, StatsEntryT]] = []
def df_traversal(node: _StatsTreeNode, depth=0):
def df_traversal(node: _StatsTreeNode[StatsEntryT], depth: int = 0):
entries.append((depth, node.entry))
for child in node.children:
df_traversal(child, depth=depth + 1)
@@ -348,10 +352,14 @@ class LayerwiseProfileResults(profile):
return entries
def _convert_stats_tree_to_dict(self, tree: list[_StatsTreeNode]) -> list[dict]:
root_dicts: list[dict] = []
def _convert_stats_tree_to_dict(
self, tree: list[_StatsTreeNode[StatsEntryT]]
) -> list[dict[str, Any]]:
root_dicts: list[dict[str, Any]] = []
def df_traversal(node: _StatsTreeNode, curr_json_list: list[dict]):
def df_traversal(
node: _StatsTreeNode[StatsEntryT], curr_json_list: list[dict[str, Any]]
):
curr_json_list.append({"entry": asdict(node.entry), "children": []})
for child in node.children:
df_traversal(child, curr_json_list[-1]["children"])