Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -30,9 +30,9 @@ def trim_string_back(string, width):
|
||||
|
||||
|
||||
class TablePrinter:
|
||||
|
||||
def __init__(self, row_cls: type[dataclasses.dataclass],
|
||||
column_widths: dict[str, int]):
|
||||
def __init__(
|
||||
self, row_cls: type[dataclasses.dataclass], column_widths: dict[str, int]
|
||||
):
|
||||
self.row_cls = row_cls
|
||||
self.fieldnames = [x.name for x in dataclasses.fields(row_cls)]
|
||||
self.column_widths = column_widths
|
||||
@@ -46,16 +46,18 @@ class TablePrinter:
|
||||
|
||||
def _print_header(self):
|
||||
for i, f in enumerate(self.fieldnames):
|
||||
last = (i == len(self.fieldnames) - 1)
|
||||
last = i == len(self.fieldnames) - 1
|
||||
col_width = self.column_widths[f]
|
||||
print(trim_string_back(f, col_width).ljust(col_width),
|
||||
end=" | " if not last else "\n")
|
||||
print(
|
||||
trim_string_back(f, col_width).ljust(col_width),
|
||||
end=" | " if not last else "\n",
|
||||
)
|
||||
|
||||
def _print_row(self, row):
|
||||
assert isinstance(row, self.row_cls)
|
||||
|
||||
for i, f in enumerate(self.fieldnames):
|
||||
last = (i == len(self.fieldnames) - 1)
|
||||
last = i == len(self.fieldnames) - 1
|
||||
col_width = self.column_widths[f]
|
||||
val = getattr(row, f)
|
||||
|
||||
@@ -75,9 +77,9 @@ class TablePrinter:
|
||||
print("=" * (total_col_width + 3 * (len(self.column_widths) - 1)))
|
||||
|
||||
|
||||
def indent_string(string: str,
|
||||
indent: int,
|
||||
indent_style: Union[Callable[[int], str], str] = " ") -> str:
|
||||
def indent_string(
|
||||
string: str, indent: int, indent_style: Union[Callable[[int], str], str] = " "
|
||||
) -> str:
|
||||
if indent:
|
||||
if isinstance(indent_style, str):
|
||||
return indent_style * indent + string
|
||||
@@ -111,15 +113,14 @@ def event_arg_repr(arg) -> str:
|
||||
elif isinstance(arg, tuple):
|
||||
return f"({', '.join([event_arg_repr(x) for x in arg])})"
|
||||
else:
|
||||
assert isinstance(arg,
|
||||
_TensorMetadata), f"Unsupported type: {type(arg)}"
|
||||
sizes_str = ', '.join([str(x) for x in arg.sizes])
|
||||
assert isinstance(arg, _TensorMetadata), f"Unsupported type: {type(arg)}"
|
||||
sizes_str = ", ".join([str(x) for x in arg.sizes])
|
||||
return f"{str(arg.dtype).replace('torch.', '')}[{sizes_str}]"
|
||||
|
||||
|
||||
def event_torch_op_repr(event: _ProfilerEvent) -> str:
|
||||
assert event.tag == _EventType.TorchOp
|
||||
args_str = ', '.join([event_arg_repr(x) for x in event.typed[1].inputs])
|
||||
args_str = ", ".join([event_arg_repr(x) for x in event.typed[1].inputs])
|
||||
return f"{event.name}({args_str})".replace("aten::", "")
|
||||
|
||||
|
||||
@@ -127,15 +128,17 @@ def event_module_repr(event: _ProfilerEvent) -> str:
|
||||
assert event_has_module(event)
|
||||
module = event.typed[1].module
|
||||
if module.parameters and len(module.parameters) > 0:
|
||||
args_str = ', '.join(
|
||||
[f'{x[0]}={event_arg_repr(x[1])}' for x in module.parameters])
|
||||
args_str = ", ".join(
|
||||
[f"{x[0]}={event_arg_repr(x[1])}" for x in module.parameters]
|
||||
)
|
||||
return f"{module.cls_name}({args_str})"
|
||||
else:
|
||||
return module.cls_name
|
||||
|
||||
|
||||
def event_torch_op_stack_trace(curr_event: _ProfilerEvent,
|
||||
until: Callable[[_ProfilerEvent], bool]) -> str:
|
||||
def event_torch_op_stack_trace(
|
||||
curr_event: _ProfilerEvent, until: Callable[[_ProfilerEvent], bool]
|
||||
) -> str:
|
||||
trace = ""
|
||||
curr_event = curr_event.parent
|
||||
while curr_event and not until(curr_event):
|
||||
|
||||
Reference in New Issue
Block a user