From f2124b93788886904ca8d5c8a84bd4def922b415 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 28 May 2026 23:08:54 +0000 Subject: [PATCH] fix: SMEM calc in decode test --- tests/unit/test_fmha_6warp_multirow.cu | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/unit/test_fmha_6warp_multirow.cu b/tests/unit/test_fmha_6warp_multirow.cu index 8554717a..2abcf2f6 100644 --- a/tests/unit/test_fmha_6warp_multirow.cu +++ b/tests/unit/test_fmha_6warp_multirow.cu @@ -197,6 +197,8 @@ int main() { cudaMemcpy(d_k, h_k, SK*HD*sizeof(bf16_t), cudaMemcpyHostToDevice); cudaMemcpy(d_v, h_v, HD*SK*sizeof(bf16_t), cudaMemcpyHostToDevice); + constexpr int TILE_SZ = 128 * MMA_K_BF16; + constexpr int V_SUB_SZ = 16 * MMA_K_BF16; size_t smem_off = 8 + 128*4 + 128*4; smem_off = ((smem_off + 127) & ~(size_t)127); smem_off += TILE_SZ + TILE_SZ;