Make various updates and fixes (#198)
This commit is contained in:
@@ -91,7 +91,7 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30,
|
||||
# Profile
|
||||
suppress = suppress_stdout_stderr if suppress_kineto_output and not using_nsys else empty_suppress
|
||||
with suppress():
|
||||
schedule = torch.profiler.schedule(wait=0, warmup=1, active=1, repeat=1) if not using_nsys else None
|
||||
schedule = torch.profiler.schedule(wait=1, warmup=0, active=1, repeat=1) if not using_nsys else None
|
||||
profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule) if not using_nsys else empty_suppress()
|
||||
with profiler:
|
||||
for i in range(2):
|
||||
@@ -112,10 +112,9 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30,
|
||||
is_tuple = isinstance(kernel_names, tuple)
|
||||
prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n')
|
||||
kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names
|
||||
assert all([isinstance(name, str) for name in kernel_names])
|
||||
if not with_multiple_kernels:
|
||||
for name in kernel_names:
|
||||
assert sum([name in line for line in prof_lines]) == 1, f'Errors of the kernel {name} in the profiling table'
|
||||
assert sum([name in line for line in prof_lines]) <= 1, f'Errors of the kernel {name} in the profiling table'
|
||||
|
||||
# Save chrome traces
|
||||
if trace_path is not None:
|
||||
@@ -136,6 +135,6 @@ def bench_kineto(fn, kernel_names, num_tests: int = 30,
|
||||
total_time += float(time_str.replace(unit, '')) / scale * int(num_str)
|
||||
total_num += int(num_str)
|
||||
break
|
||||
kernel_times.append(total_time / total_num)
|
||||
kernel_times.append(total_time / total_num if total_num > 0 else 0)
|
||||
|
||||
return tuple(kernel_times) if is_tuple else kernel_times[0]
|
||||
|
||||
Reference in New Issue
Block a user