Files
DeepGEMM/deep_gemm/testing/bench.py
2025-11-21 17:49:47 +08:00

138 lines
4.6 KiB
Python

import os
import sys
import torch
def bench(fn, num_warmups: int = 5, num_tests: int = 10,
          high_precision: bool = False):
    """Time `fn` with CUDA events and return the mean duration of one call.

    Args:
        fn: zero-argument callable that launches the work to be measured.
        num_warmups: number of untimed calls issued before measuring.
        num_tests: number of timed calls; the elapsed time is divided by it.
        high_precision: when true, enqueue one large matmul right before the
            timed region so the CPU launch overhead is eliminated.

    Returns:
        The event-measured elapsed time divided by `num_tests` and by 1e3
        (CUDA events report milliseconds, so this is seconds per call).
    """
    # Flush L2 cache with 256 MB data
    torch.cuda.synchronize()
    flush_buffer = torch.empty(int(256e6 // 4), dtype=torch.int, device='cuda')
    flush_buffer.zero_()

    # Untimed warmup calls
    for _ in range(num_warmups):
        fn()

    # Add a large kernel to eliminate the CPU launch overhead
    if high_precision:
        big_shape = (8192, 8192)
        lhs = torch.randn(big_shape, dtype=torch.float, device='cuda')
        rhs = torch.randn(big_shape, dtype=torch.float, device='cuda')
        torch.matmul(lhs, rhs)

    # Timed region, bracketed by a pair of timing-enabled events
    tic = torch.cuda.Event(enable_timing=True)
    toc = torch.cuda.Event(enable_timing=True)
    tic.record()
    for _ in range(num_tests):
        fn()
    toc.record()

    # Both events must have completed before the elapsed time can be read
    torch.cuda.synchronize()
    elapsed_ms = tic.elapsed_time(toc)
    return elapsed_ms / num_tests / 1e3
class empty_suppress:
    """No-op context manager, used where output suppression is not requested."""

    def __enter__(self):
        # Nothing to set up; hand back the manager itself
        return self

    def __exit__(self, *exc_info):
        # Returning None lets any exception raised in the `with` body propagate
        return None
class suppress_stdout_stderr:
    """Context manager that silences stdout and stderr while it is active.

    Both layers are redirected to the null device: the file descriptors behind
    `sys.stdout` / `sys.stderr` (via `os.dup2`) and the `sys.stdout` /
    `sys.stderr` objects themselves. Everything is put back on exit.
    """

    def __enter__(self):
        # One independent sink on the null device per stream
        self.outnull_file = open(os.devnull, 'w')
        self.errnull_file = open(os.devnull, 'w')

        # Descriptors currently backing the Python-level streams
        stdout_fd = sys.stdout.fileno()
        stderr_fd = sys.stderr.fileno()
        self.old_stdout_fileno_undup = stdout_fd
        self.old_stderr_fileno_undup = stderr_fd

        # Duplicates keep the original targets reachable after the redirect
        self.old_stdout_fileno = os.dup(stdout_fd)
        self.old_stderr_fileno = os.dup(stderr_fd)
        self.old_stdout, self.old_stderr = sys.stdout, sys.stderr

        # Point the descriptors at the sinks first, then swap the stream objects
        redirects = ((self.outnull_file, stdout_fd), (self.errnull_file, stderr_fd))
        for sink, target_fd in redirects:
            os.dup2(sink.fileno(), target_fd)
        sys.stdout, sys.stderr = self.outnull_file, self.errnull_file
        return self

    def __exit__(self, *_):
        # Undo in reverse: stream objects first, then the descriptors
        sys.stdout, sys.stderr = self.old_stdout, self.old_stderr
        saved = (
            (self.old_stdout_fileno, self.old_stdout_fileno_undup),
            (self.old_stderr_fileno, self.old_stderr_fileno_undup),
        )
        for saved_fd, target_fd in saved:
            os.dup2(saved_fd, target_fd)

        # The duplicates and the sinks are no longer needed
        for saved_fd, _unused in saved:
            os.close(saved_fd)
        for sink in (self.outnull_file, self.errnull_file):
            sink.close()
def _parse_profiler_duration(time_str):
    """Convert a profiler-table duration such as ``'12.345us'`` to seconds.

    Returns ``None`` when `time_str` is not a number followed by one of the
    suffixes ``ms``, ``us`` or ``s``, so callers can skip non-data rows.
    """
    # `ms` and `us` must be tried before the bare `s` that both of them end with
    for suffix, per_second in (('ms', 1e3), ('us', 1e6), ('s', 1.0)):
        if time_str.endswith(suffix):
            try:
                return float(time_str[:-len(suffix)]) / per_second
            except ValueError:
                return None
    return None


def bench_kineto(fn, kernel_names, num_tests: int = 30,
                 suppress_kineto_output: bool = False,
                 trace_path: str = None, flush_l2: bool = True,
                 with_multiple_kernels: bool = False):
    """Measure the average time of named kernels with the PyTorch profiler.

    `fn` is run once up front, then `2 * num_tests` times under a CUDA-activity
    profiler with a `wait=1, active=1` schedule stepped after each batch of
    `num_tests` calls. Rows of the resulting summary table whose text contains
    a requested name are averaged, weighted by their call count.

    Args:
        fn: zero-argument callable that launches the kernels of interest.
        kernel_names: a substring, or a tuple of substrings, matched against
            the lines of the profiler summary table.
        num_tests: calls of `fn` per profiler step.
        suppress_kineto_output: silence stdout/stderr while profiling.
        trace_path: when not ``None``, a chrome trace is exported to this path.
        flush_l2: zero a large scratch tensor before every profiled call.
        with_multiple_kernels: allow a name to match more than one table line;
            otherwise more than one match fails an assertion.

    Returns:
        Average time in seconds per requested name (``0`` when no row
        matched): a tuple when `kernel_names` is a tuple, else a single value.
        When the environment variable ``DG_USE_NVIDIA_TOOLS`` is a non-zero
        integer, profiling is skipped and ``1`` is returned for every name.

    Raises:
        AssertionError: `kernel_names` is neither `str` nor `tuple`, or a name
            matches several table lines while `with_multiple_kernels` is false.
    """
    assert isinstance(kernel_names, str) or isinstance(kernel_names, tuple)
    is_tuple = isinstance(kernel_names, tuple)

    # Skip profiling
    # Conflict with Nsight Systems, Nsight Compute and Compute Sanitizer
    if int(os.environ.get('DG_USE_NVIDIA_TOOLS', 0)):
        return (1, ) * len(kernel_names) if is_tuple else 1

    # By default, flush L2 with an excessive 8 GB memset to give the GPU some (literal) chill time without full idle
    flush_l2_size = int(8e9 // 4)

    # For some auto-tuning kernels with prints
    fn()

    # Profile
    suppress = suppress_stdout_stderr if suppress_kineto_output else empty_suppress
    with suppress():
        schedule = torch.profiler.schedule(wait=1, warmup=0, active=1, repeat=1)
        profiler = torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA], schedule=schedule)
        with profiler:
            # Two steps of `num_tests` calls each, matching the wait=1 / active=1 schedule
            for _ in range(2):
                for _ in range(num_tests):
                    if flush_l2:
                        torch.empty(flush_l2_size, dtype=torch.int, device='cuda').zero_()
                    fn()
                profiler.step()

    # Parse the profiling table
    prof_lines = profiler.key_averages().table(sort_by='cuda_time_total', max_name_column_width=100).split('\n')
    kernel_names = (kernel_names, ) if isinstance(kernel_names, str) else kernel_names
    if not with_multiple_kernels:
        for name in kernel_names:
            assert sum([name in line for line in prof_lines]) <= 1, f'Errors of the kernel {name} in the profiling table'

    # Save chrome traces
    if trace_path is not None:
        profiler.export_chrome_trace(trace_path)

    # Return average kernel times
    kernel_times = []
    for name in kernel_names:
        total_time = 0
        total_num = 0
        for line in prof_lines:
            if name not in line:
                continue
            columns = line.split()
            if len(columns) < 2:
                continue
            # The last two columns are read as (average time, call count) -- assumed table layout
            seconds = _parse_profiler_duration(columns[-2])
            if seconds is None:
                # Not a timing row (e.g. a header or footer that happens to contain the name)
                continue
            num_calls = int(columns[-1])
            total_time += seconds * num_calls
            total_num += num_calls
        kernel_times.append(total_time / total_num if total_num > 0 else 0)

    return tuple(kernel_times) if is_tuple else kernel_times[0]