test: 32 TMEM cols, add MMA call with N=32, read S from TMEM
This commit is contained in:
@@ -66,7 +66,7 @@ test_umma_qk_hd16(
|
||||
// ================================================================
|
||||
if (wid == 0) {
|
||||
uint32_t smem_ptr = __cvta_generic_to_shared(sTmemBase);
|
||||
tmem_alloc(smem_ptr, 128);
|
||||
tmem_alloc(smem_ptr, 32); // 32 columns like minimal test
|
||||
}
|
||||
__syncthreads();
|
||||
uint32_t tmem_base = *sTmemBase;
|
||||
@@ -78,14 +78,14 @@ test_umma_qk_hd16(
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Zero TMEM — test 2 stores with syncwarp between them
|
||||
if (wid == 0) {
|
||||
tmem_store(tmem_base, 0, 0, 0, 0);
|
||||
__syncwarp();
|
||||
tmem_store(tmem_base + 1, 0, 0, 0, 0);
|
||||
__syncwarp();
|
||||
}
|
||||
__syncthreads();
|
||||
// Zero TMEM — skip (minimal test didn't zero)
|
||||
// if (wid == 0) {
|
||||
// tmem_store(tmem_base, 0, 0, 0, 0);
|
||||
// __syncwarp();
|
||||
// tmem_store(tmem_base + 1, 0, 0, 0, 0);
|
||||
// __syncwarp();
|
||||
// }
|
||||
// __syncthreads();
|
||||
|
||||
if (tid == 0) {
|
||||
s_out[140] = 2.0f; // sentinel: survived tmem_store
|
||||
@@ -136,13 +136,30 @@ test_umma_qk_hd16(
|
||||
// __syncthreads();
|
||||
|
||||
// ================================================================
|
||||
// Read S from TMEM — skipped (no TMEM allocated)
|
||||
// Call tcgen05.mma SS — test with 32 columns (N=32)
|
||||
// ================================================================
|
||||
// Rebuild idesc for N=32
|
||||
uint32_t idesc = make_idesc(128, 32);
|
||||
|
||||
if (tid == 0) {
|
||||
umma_ss_f16(tmem_base, desc_q, desc_k, idesc, /*accumulate=*/false);
|
||||
}
|
||||
__syncwarp();
|
||||
if (wid == 0 && lane == 0) {
|
||||
tmem_fence_store();
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// ================================================================
|
||||
// Read S from TMEM
|
||||
// ================================================================
|
||||
// Just compute scalar reference and output descriptor debug info
|
||||
if (wid == 0) {
|
||||
// Placeholder: write zeros
|
||||
for (int col = 0; col < 128; col++) {
|
||||
if (lane == 0) s_out[col] = 0.0f;
|
||||
for (int col = 0; col < 32; col++) {
|
||||
uint32_t u0, u1, u2, u3;
|
||||
tmem_load(tmem_base + col, u0, u1, u2, u3);
|
||||
if (lane == 0) {
|
||||
s_out[col] = u32_to_f32(u0); // S[0, col]
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
@@ -163,7 +180,7 @@ test_umma_qk_hd16(
|
||||
|
||||
// TMEM dealloc
|
||||
if (wid == 0) {
|
||||
tmem_dealloc(tmem_base, 128);
|
||||
tmem_dealloc(tmem_base, 32);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user