diff --git a/dsv4/kernels/cuda/test_fp8_gemm_tmem_read.cu b/dsv4/kernels/cuda/test_fp8_gemm_tmem_read.cu index 4253aacb..973b48ef 100644 --- a/dsv4/kernels/cuda/test_fp8_gemm_tmem_read.cu +++ b/dsv4/kernels/cuda/test_fp8_gemm_tmem_read.cu @@ -172,11 +172,17 @@ test_fp8_gemm_tmem_read_kernel( } // Row group 32-63 + // Try different TMEM strides to find the correct offset float tmp2[8] = {}; + // The TMEM layout for UMMA output may use stride = N/8 per row group + // For N=128, stride = 16. But let's try SK_TILE (128) first. + // Empirically: tb + col_base gives rows 0-31 correctly. + // We need to find where rows 32-63 are. + // Try: tb + (SK_TILE / 8) + col_base = tb + 16 + col_base asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];" : "=f"(tmp2[0]),"=f"(tmp2[1]),"=f"(tmp2[2]),"=f"(tmp2[3]), "=f"(tmp2[4]),"=f"(tmp2[5]),"=f"(tmp2[6]),"=f"(tmp2[7]) - : "r"(tb + SK_TILE + col_base)); + : "r"(tb + (SK_TILE / 8) + col_base)); asm volatile("tcgen05.wait::ld.sync.aligned;" ::: "memory"); if (lane < n_ih - 32 && lane < 32) { int h = lane + 32;