[Bugfix] Fix incorrect types in LayerwiseProfileResults (#12196)
Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
This commit is contained in:
@@ -1,7 +1,7 @@
|
|||||||
import copy
|
import copy
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from dataclasses import asdict, dataclass, field
|
from dataclasses import asdict, dataclass, field
|
||||||
from typing import Callable, Dict, List, Optional, Tuple, TypeAlias, Union
|
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeAlias, Union
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from torch._C._autograd import DeviceType, _KinetoEvent, _ProfilerResult
|
from torch._C._autograd import DeviceType, _KinetoEvent, _ProfilerResult
|
||||||
@@ -128,7 +128,7 @@ class LayerwiseProfileResults(profile):
|
|||||||
])
|
])
|
||||||
df.to_csv(filename)
|
df.to_csv(filename)
|
||||||
|
|
||||||
def convert_stats_to_dict(self) -> str:
|
def convert_stats_to_dict(self) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"num_running_seqs": self.num_running_seqs
|
"num_running_seqs": self.num_running_seqs
|
||||||
@@ -227,7 +227,7 @@ class LayerwiseProfileResults(profile):
|
|||||||
[self._cumulative_cuda_time(root) for root in self._module_tree])
|
[self._cumulative_cuda_time(root) for root in self._module_tree])
|
||||||
|
|
||||||
def _build_stats_trees(self):
|
def _build_stats_trees(self):
|
||||||
summary_dict: Dict[str, self.StatsTreeNode] = {}
|
summary_dict: Dict[str, _StatsTreeNode] = {}
|
||||||
total_cuda_time = self._total_cuda_time()
|
total_cuda_time = self._total_cuda_time()
|
||||||
|
|
||||||
def pct_cuda_time(cuda_time_us):
|
def pct_cuda_time(cuda_time_us):
|
||||||
|
|||||||
Reference in New Issue
Block a user