From 2a765be7153dabcb6f9530ea036c56ec50d26846 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 28 May 2026 08:44:55 +0000 Subject: [PATCH] fix: correct SMEM size for row-major (not swizzled) --- tests/unit/test_qk_mma.cu | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/tests/unit/test_qk_mma.cu b/tests/unit/test_qk_mma.cu index ce5ed323..7b10b973 100644 --- a/tests/unit/test_qk_mma.cu +++ b/tests/unit/test_qk_mma.cu @@ -64,14 +64,8 @@ int main() { cudaMemcpy(dk, hkb, B*s_k*HD*2, cudaMemcpyHostToDevice); cudaMemset(ds_out, 0, s_k*4); - // SMEM = 4 (tmem_base) + 128B align + 128*HD*2*2 (sQ + sK, swizzled) + slack - // Swizzled layout may use more space than row-major due to atom padding - // MN_SW128 atom = 1024*8 = 8192 BF16 per atom. For (128, HD) = 128*HD BF16. - // With 1 MN tile and ceil(HD/8) K tiles, total = 8192 * ceil(HD/8) BF16 - int atoms_n = (HD + 7) / 8; // number of K-tiles (atom has 8 BF16 in K dim) - int smem_q = 1024 * 8 * atoms_n; // BF16 elements for Q (padded) - int smem_k = 1024 * 8 * atoms_n; // BF16 elements for K (padded) - int smem = 4 + 128 + smem_q * 2 + smem_k * 2 + 1024; // bytes + // SMEM = 4 (tmem_base) + 128B align + 128*HD*2 (sQ) + 128*HD*2 (sK) + slack + int smem = 4 + 128 + 128 * HD * 2 + 128 * HD * 2 + 4096; dim3 grid(1, H, B); dim3 block(NTHREADS);