diff --git a/tests/unit/test_tmem_4warp_read.cu b/tests/unit/test_tmem_4warp_read.cu index 3a3b8b97..c7ee6678 100644 --- a/tests/unit/test_tmem_4warp_read.cu +++ b/tests/unit/test_tmem_4warp_read.cu @@ -133,7 +133,7 @@ int main() { // SMEM: sbuf(8) + sRowMax(512) + align(128) + sQ0(4096) + sK0(4096) + slack(256) = 9000 size_t smem_off = 8 + 128*4; smem_off = ((smem_off + 127) & ~(size_t)127); - smem_off += TILE_SZ * 2 + 256; + smem_off += TILE_SZ * 2 * 2 + 256; // sQ0 + sK0 (each TILE_SZ BF16 = 4096 bytes) + slack int smem = (int)smem_off; test_mma_rows<<<1, 192, smem>>>(d_r);