From 3fc6728dee6515eb7702a0c94dfb7921b02d8527 Mon Sep 17 00:00:00 2001 From: fy1214 <93441374+fy1214@users.noreply.github.com> Date: Wed, 2 Jul 2025 14:05:36 +0800 Subject: [PATCH] [add] fix smem_barrier size in wgrad way (#122) --- deep_gemm/jit_kernels/gemm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deep_gemm/jit_kernels/gemm.py b/deep_gemm/jit_kernels/gemm.py index 574f821..2b0c67a 100644 --- a/deep_gemm/jit_kernels/gemm.py +++ b/deep_gemm/jit_kernels/gemm.py @@ -49,7 +49,7 @@ def get_smem_config(num_stages: int, k: int, block_m: int, block_n: int, block_k smem_b_per_stage = block_n * block_k smem_scales_b_per_stage = ceil_div(block_n * 4, block_k) * block_k if is_wgrad else 0 smem_scales_b = ceil_div(k, block_k) * 4 if not is_wgrad else 0 - smem_barrier = num_stages * 8 * 2 + smem_barrier = num_stages * 8 * 2 if not is_wgrad else (num_stages + 1) * 8 * 2 smem_size = 0 smem_size += smem_d