B2 TMEM debug: try stride=SK_TILE/8=16 for row group 32-63
This commit is contained in:
@@ -172,11 +172,17 @@ test_fp8_gemm_tmem_read_kernel(
|
||||
}
|
||||
|
||||
// Row group 32-63
|
||||
// Try different TMEM strides to find the correct offset
|
||||
float tmp2[8] = {};
|
||||
// The TMEM layout for UMMA output may use stride = N/8 per row group
|
||||
// For N=128, stride = 16. The previous attempt used SK_TILE (128) as the offset.
|
||||
// Empirically: tb + col_base gives rows 0-31 correctly.
|
||||
// We need to find where rows 32-63 are.
|
||||
// Try: tb + (SK_TILE / 8) + col_base = tb + 16 + col_base
|
||||
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
|
||||
: "=f"(tmp2[0]),"=f"(tmp2[1]),"=f"(tmp2[2]),"=f"(tmp2[3]),
|
||||
"=f"(tmp2[4]),"=f"(tmp2[5]),"=f"(tmp2[6]),"=f"(tmp2[7])
|
||||
: "r"(tb + SK_TILE + col_base));
|
||||
: "r"(tb + (SK_TILE / 8) + col_base));
|
||||
asm volatile("tcgen05.wait::ld.sync.aligned;" ::: "memory");
|
||||
if (lane < n_ih - 32 && lane < 32) {
|
||||
int h = lane + 32;
|
||||
|
||||
Reference in New Issue
Block a user