From bed67b234cf8341015dc6c5b99759bc4479a7f3e Mon Sep 17 00:00:00 2001
From: sazc
Date: Mon, 10 Mar 2025 13:02:02 +0800
Subject: [PATCH] Minor fix

---
 deep_gemm/jit_kernels/gemm.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py
index 33336a8..082c8ee 100644
--- a/deep_gemm/jit_kernels/gemm.py
+++ b/deep_gemm/jit_kernels/gemm.py
@@ -105,8 +105,8 @@ def get_best_configs(m: int, n: int, k: int, num_groups: int, num_sms: int,
     # NOTES: less L2 cache usage and less GPU frequency drop
     num_waves = get_num_waves(best_block_m, best_block_n)
     num_min_sms = ceil_div(ceil_div(m, best_block_m) * ceil_div(n, best_block_n) * num_groups, num_waves)
-    num_min_sms = ceil_div(max(num_min_sms, num_sms - 8), 2) * 2
-    assert num_min_sms <= num_sms
+    num_min_sms = ceil_div(max(num_min_sms, num_sms - 8), best_num_tma_multicast) * best_num_tma_multicast
+    assert num_min_sms <= num_sms and is_tma_multicast_legal(n, best_block_n, best_num_tma_multicast, num_min_sms)
 
     return num_min_sms, best_block_m, best_block_n, best_num_stages, best_num_tma_multicast, best_smem_size
 