diff --git a/tests/unit/test_fmha_6warp_multirow.cu b/tests/unit/test_fmha_6warp_multirow.cu index 160c4a16..2b9cfef3 100644 --- a/tests/unit/test_fmha_6warp_multirow.cu +++ b/tests/unit/test_fmha_6warp_multirow.cu @@ -120,9 +120,21 @@ static int test_single_T(int T, int n_h = 1, int batch = 1) { params.lse_head_stride = T; params.lse_batch_stride = n_h * T; - // SMEM: sTmemBase(8) + sRowMax(512) + sRowSum(512) + padding(128) + sQ0(4096) + sK0(4096) + sPk(4096) + sV(512) + slack(256) - // NO s_p_vals buffer — P is streamed per K-tile through sPk - int smem = 8 + MAX_T*4 + MAX_T*4 + 128 + TILE_SZ*2 + TILE_SZ*2 + V_SUB_SZ*2 + 256; + // SMEM: must match kernel's layout exactly. + // sTmemBase(8) + sRowMax(512) + sRowSum(512) + align128 + sQ0(4096) + sK0(4096) + align128 + sPk(4096) + align128 + sV(512) + slack + // Pessimistic: just compute with alignment + size_t smem_off = 0; + smem_off += 8; // sTmemBase + alignment + smem_off += 128 * sizeof(float) * 2; // sRowMax + sRowSum + smem_off = (smem_off + 127) & ~(size_t)127; // align for sQ0 + smem_off += 128 * MMA_K_BF16 * sizeof(bf16_t); // sQ0 + smem_off += 128 * MMA_K_BF16 * sizeof(bf16_t); // sK0 + smem_off = (smem_off + 127) & ~(size_t)127; // align for sPk + smem_off += 128 * MMA_K_BF16 * sizeof(bf16_t); // sPk + smem_off = (smem_off + 127) & ~(size_t)127; // align for sV + smem_off += 16 * MMA_K_BF16 * sizeof(bf16_t); // sV + smem_off += 256; // slack + int smem = (int)smem_off; smem = (smem + 127) & ~127; if (smem > 48 * 1024) {