B2 TMEM debug: try stride=SK_TILE/8=16 for row group 32-63

This commit is contained in:
2026-06-03 00:52:32 +00:00
parent fdf702470c
commit 8d0a02ca67

View File

@@ -172,11 +172,17 @@ test_fp8_gemm_tmem_read_kernel(
}
// Row group 32-63
// Try different TMEM strides to find the correct offset
float tmp2[8] = {};
// The TMEM layout for UMMA output may use stride = N/8 per row group
// For N=128, stride = 16. But let's try SK_TILE (128) first.
// Empirically: tb + col_base gives rows 0-31 correctly.
// We need to find where rows 32-63 are.
// Try: tb + (SK_TILE / 8) + col_base = tb + 16 + col_base
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
: "=f"(tmp2[0]),"=f"(tmp2[1]),"=f"(tmp2[2]),"=f"(tmp2[3]),
"=f"(tmp2[4]),"=f"(tmp2[5]),"=f"(tmp2[6]),"=f"(tmp2[7])
: "r"(tb + SK_TILE + col_base));
: "r"(tb + (SK_TILE / 8) + col_base));
asm volatile("tcgen05.wait::ld.sync.aligned;" ::: "memory");
if (lane < n_ih - 32 && lane < 32) {
int h = lane + 32;