/**
 * Systematic test: Can tcgen05.mma SS and TS coexist in the same kernel?
 *
 * The crash from combining QK (SS) + PV (TS) in the same kernel suggests
 * a TMEM state issue. This test systematically varies:
 *   1. TMEM allocation size
 *   2. Whether SS and TS share the same TMEM region or use separate regions
 *   3. Whether there's a TMEM dealloc+realloc between SS and TS
 *   4. Whether the issue is SS→TS specifically, or any second MMA call
 *
 * The working test_mma_ts.cu uses TS alone and works. test_fmha_hd64.cu
 * uses SS alone (with register-math PV) and works. The crash only happens
 * when both are in the same kernel.
 *
 * Key hypothesis: After tcgen05.mma SS, the TMEM C operand's internal
 * state may conflict with the TS MMA's A operand read. The SS MMA writes
 * to TMEM columns in Layout D format. The TS MMA reads A from TMEM and
 * expects it in a specific format. If the formats don't match, or if
 * there's a hardware fence missing between SS and TS, that could cause
 * the crash.
 *
 * Alternative hypothesis: The issue is simpler — TMEM column accounting.
 * SS writes 128 columns. TS reads from columns 0-15 (subset of P) and
 * writes to columns 128-143 (O region). Maybe the SS MMA reserves or
 * locks TMEM columns in a way that TS can't access.
 *
 * Test plan:
 *   Phase 1: SS alone (baseline — should work)
 *   Phase 2: TS alone (baseline — should work)
 *   Phase 3: SS then TS, same TMEM (the crash case)
 *   Phase 4: SS then TS, separate TMEM allocs (dealloc after SS, realloc for TS)
 *   Phase 5: Two SS calls in sequence (does SS→SS work?)
 *   Phase 6: Two TS calls in sequence (does TS→TS work?)
 *   Phase 7: SS then TS with explicit TMEM barrier
 *            (planned — no Phase 7 kernel is defined or launched in this file yet)
 *
 * Every kernel here is launched as 1 block x 128 threads with dynamic
 * shared memory (see main()).
 */

// NOTE(review): the header names of these four directives were missing in the
// reviewed copy of this file; they are reconstructed from the names the file
// uses (printf, uint32_t/uintptr_t, CUDA runtime API) — confirm against VCS.
#include <cstdint>
#include <cstdio>
#include <cuda_runtime.h>
#include <cuda_bf16.h>

#include "dsv4/kernels/attention/fmha_common.cuh"
#include "dsv4/kernels/attention/fmha_umma_desc.cuh"

using namespace dsv4::kernels::attention;

// ============================================================
// Helper: fill SMEM (16,16) canonical with a constant (default: all-1s)
//
// Block-collective: every thread of the block must call it, because it
// contains __syncthreads(). Logical element (r, c) is stored at
//   (c/8)*2*64 + (r/8)*64 + (r%8)*8 + (c%8)
// i.e. 8x8 tiles of 64 elements each. The mapping covers all 256 slots,
// so the initial zero pass is purely defensive.
// ============================================================
__device__ void fill_smem_16x16_ones(bf16_t* s, float val = 1.0f) {
    constexpr int kElems = 16 * 16;
    // Stride by the actual block width (128 for every launch in this file).
    const int stride = blockDim.x;

    for (int i = threadIdx.x; i < kElems; i += stride) s[i] = 0;
    __syncthreads();

    for (int i = threadIdx.x; i < kElems; i += stride) {
        int r = i / 16, c = i % 16;
        int ck = c / 8, lc = c % 8;     // 8-column chunk index, column inside it
        int tmn = r / 8, lr = r % 8;    // 8-row tile index, row inside it
        s[ck * 2 * 64 + tmn * 64 + lr * 8 + lc] = f32_to_bf16(val);
    }
    __syncthreads();
}

// ============================================================
// Helper: fill SMEM (128,16) canonical with a constant (for Q/K)
//
// Same tiling as the (16,16) helper but with 16 row tiles per column
// chunk. Block-collective (contains __syncthreads()).
// ============================================================
__device__ void fill_smem_128x16_ones(bf16_t* s, float val = 1.0f) {
    constexpr int CORES_MN = 16;        // number of 8-row tiles along the 128 rows
    constexpr int kElems = 128 * 16;
    const int stride = blockDim.x;

    for (int i = threadIdx.x; i < kElems; i += stride) s[i] = 0;
    __syncthreads();

    for (int i = threadIdx.x; i < kElems; i += stride) {
        int r = i / 16, c = i % 16;
        int ck = c / 8, lc = c % 8;
        int tmn = r / 8, lr = r % 8;
        s[ck * CORES_MN * 64 + tmn * 64 + lr * 8 + lc] = f32_to_bf16(val);
    }
    __syncthreads();
}

// ============================================================
// Helper: write A=val to TMEM cols [col_start, col_start+16)
// using 32x32b.x8 stores (warp-collective, all lanes write val)
//
// Must be executed by a whole warp: tcgen05.st is .sync.aligned.
// The caller is responsible for the store fence afterwards.
// NOTE(review): every caller in this file invokes this from warp 0 only and
// the address carries no per-warp lane offset, so presumably only the TMEM
// lanes reachable from warp 0 are written, not all 128 rows — confirm against
// the PTX tcgen05.st lane-access rules. Only row 0 is ever printed.
// NOTE(review): values are stored as 32-bit floats while the TS MMA consumes
// A through a *_f16 helper — confirm the intended A-operand encoding.
// ============================================================
__device__ void tmem_write_ones_128x16(uint32_t tb, int col_start, float val = 1.0f) {
    for (int n = 0; n < 2; n++) {  // 16 cols / 8 = 2 iterations
        float p0 = val, p1 = val, p2 = val, p3 = val,
              p4 = val, p5 = val, p6 = val, p7 = val;
        asm volatile("tcgen05.st.sync.aligned.32x32b.x8.b32 [%0],{%1,%2,%3,%4,%5,%6,%7,%8};"
            :: "r"(tb + col_start + n*8),
               "f"(p0),"f"(p1),"f"(p2),"f"(p3),"f"(p4),"f"(p5),"f"(p6),"f"(p7));
    }
}

// ============================================================
// Helper: read 16 TMEM cols starting at col_start, print row 0
//
// Must be executed by a whole warp (tcgen05.ld / tcgen05.wait::ld are
// .sync.aligned). All lanes perform the loads; only lane 0 keeps and
// prints its 16 values, prefixed with `label`.
// ============================================================
__device__ void tmem_read_print_16(uint32_t tb, int col_start, const char* label) {
    float vals[16];
    const int lane = threadIdx.x % 32;

    for (int n = 0; n < 2; n++) {
        float tmp[8];
        asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
            : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
              "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
            : "r"(tb + col_start + n*8));
        // The loaded registers are only valid after the wait.
        asm volatile("tcgen05.wait::ld.sync.aligned;");
        if (lane == 0)
            for (int c = 0; c < 8; c++) vals[n*8+c] = tmp[c];
    }

    if (lane == 0) {
        printf("%s: ", label);
        for (int c = 0; c < 16; c++) printf("%.1f ", vals[c]);
        printf("\n");
    }
}

// ============================================================
// PHASE 1: SS alone
//
// Launch: 1 block x 128 threads. Dynamic SMEM: 4-byte TMEM base slot,
// then two 16-byte-aligned (128,16) bf16 tiles (A filled with 1.0,
// B filled with 2.0). Allocates 128 TMEM columns.
// NOTE(review): tcgen05.mma is asynchronous per the PTX ISA. Unless
// umma_ss_f16 commits and waits internally (not visible from this file),
// the fence + __syncthreads() after the MMA may not guarantee the result is
// complete before it is read; the documented hand-off pattern is
// fence::before_thread_sync -> barrier -> fence::after_thread_sync. Confirm.
// ============================================================
__global__ void __launch_bounds__(128) test_phase1_ss() {
    const int tid = threadIdx.x, wid = tid / 32;
    extern __shared__ char sbuf[];
    uint32_t* sTmemBase = (uint32_t*)sbuf;
    bf16_t* sA = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15);
    bf16_t* sB = sA + 128 * 16;  // Second (128,16) buffer for B

    fill_smem_128x16_ones(sA, 1.0f);
    fill_smem_128x16_ones(sB, 2.0f);

    // Warp 0 allocates; the base address is published through shared memory.
    if (wid == 0) tmem_alloc(__cvta_generic_to_shared(sTmemBase), 128);
    __syncthreads();
    uint32_t tb = *sTmemBase;

    // SS: A(128,16) × B(128,16) → C(128,128) at tb
    uint64_t da = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sA), 128);
    uint64_t db = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sB), 128);
    uint32_t idesc = make_idesc(128, 128);
    if (tid == 0) umma_ss_f16(tb, da, db, idesc, false);  // single issuing thread
    asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    __syncthreads();

    // Read first 16 cols of S
    if (wid == 0) tmem_read_print_16(tb, 0, "Phase1 SS S[0,0..15]");
    __syncthreads();
    if (wid == 0) tmem_dealloc(tb, 128);
}

// ============================================================
// PHASE 2: TS alone
//
// Launch: 1 block x 128 threads. Dynamic SMEM: 4-byte TMEM base slot,
// then one 16-byte-aligned (16,16) bf16 tile (V filled with 2.0).
// Allocates 64 TMEM columns: A at cols 0-15, C at cols 32-47.
// ============================================================
__global__ void __launch_bounds__(128) test_phase2_ts() {
    const int tid = threadIdx.x, wid = tid / 32;
    extern __shared__ char sbuf[];
    uint32_t* sTmemBase = (uint32_t*)sbuf;
    bf16_t* sV = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15);

    fill_smem_16x16_ones(sV, 2.0f);

    if (wid == 0) tmem_alloc(__cvta_generic_to_shared(sTmemBase), 64);
    __syncthreads();
    uint32_t tb = *sTmemBase;

    // Write A = all 1.0 to cols 0-15
    if (wid == 0) {
        tmem_write_ones_128x16(tb, 0, 1.0f);
        tmem_fence_store();
    }
    __syncthreads();

    // TS: A(128,16, TMEM) × B(16,16, SMEM) → C(128,16, TMEM) at tb+32
    uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sV), 16);
    uint32_t idesc = make_idesc(128, 16);
    if (tid == 0) umma_ts_f16(tb + 32, tb, dv, idesc, false);
    asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    __syncthreads();

    // Read C
    if (wid == 0) tmem_read_print_16(tb, 32, "Phase2 TS C[0,0..15]");
    __syncthreads();
    if (wid == 0) tmem_dealloc(tb, 64);
}

// ============================================================
// PHASE 3: SS then TS, same TMEM allocation
// This is the crash case we need to debug.
//
// Launch: 1 block x 128 threads. Dynamic SMEM: 4-byte TMEM base slot,
// two (128,16) bf16 tiles for the SS operands, then a 128-byte-aligned
// (16,16) bf16 tile for V.
// NOTE(review): the TS step reads its A operand straight from the first 16
// columns of the SS accumulator without any conversion — presumably this
// mirrors the failing production kernel; confirm that is intended.
// ============================================================
__global__ void __launch_bounds__(128) test_phase3_ss_then_ts() {
    const int tid = threadIdx.x, wid = tid / 32;
    extern __shared__ char sbuf[];
    uint32_t* sTmemBase = (uint32_t*)sbuf;
    bf16_t* sAB = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15);
    bf16_t* sV = (bf16_t*)(((uintptr_t)(sAB + 2 * 128 * 16) + 127) & ~(uintptr_t)127);

    // SS buffers
    fill_smem_128x16_ones(sAB, 1.0f);
    fill_smem_128x16_ones(sAB + 128 * 16, 2.0f);
    // TS buffer
    fill_smem_16x16_ones(sV, 2.0f);

    // TMEM: 256 columns. 0-127 = SS output (S). 128-143 = TS output (O).
    if (wid == 0) tmem_alloc(__cvta_generic_to_shared(sTmemBase), 256);
    __syncthreads();
    uint32_t tb = *sTmemBase;

    // STEP 1: SS QK GEMM → S at tb (cols 0-127)
    uint64_t da = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sAB), 128);
    uint64_t db = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sAB + 128 * 16), 128);
    uint32_t idesc_ss = make_idesc(128, 128);
    if (tid == 0) umma_ss_f16(tb, da, db, idesc_ss, false);
    asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    __syncthreads();

    // Verify SS output
    if (wid == 0) tmem_read_print_16(tb, 0, "Phase3 SS S[0,0..15]");

    // STEP 2: TS PV GEMM → O at tb+128 (cols 128-143)
    //   A = first 16 cols of S (tb + 0)
    //   B = sV (16, 16)
    //   C = O at tb + 128
    uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sV), 16);
    uint32_t idesc_ts = make_idesc(128, 16);
    if (tid == 0) umma_ts_f16(tb + 128, tb, dv, idesc_ts, false);
    asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    __syncthreads();

    // Read TS output
    if (wid == 0) tmem_read_print_16(tb, 128, "Phase3 TS O[0,0..15]");
    __syncthreads();
    if (wid == 0) tmem_dealloc(tb, 256);
}

// ============================================================
// PHASE 4: SS then TS, with TMEM dealloc + realloc between them
// Tests whether the crash is due to TMEM state from SS interfering with TS
//
// Launch: 1 block x 128 threads. Same dynamic SMEM layout as Phase 3.
// ============================================================
__global__ void __launch_bounds__(128) test_phase4_ss_ts_separate_tmem() {
    const int tid = threadIdx.x, wid = tid / 32, lane = tid % 32;
    extern __shared__ char sbuf[];
    uint32_t* sTmemBase = (uint32_t*)sbuf;
    bf16_t* sAB = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15);
    bf16_t* sV = (bf16_t*)(((uintptr_t)(sAB + 2 * 128 * 16) + 127) & ~(uintptr_t)127);

    fill_smem_128x16_ones(sAB, 1.0f);
    fill_smem_128x16_ones(sAB + 128 * 16, 2.0f);
    fill_smem_16x16_ones(sV, 2.0f);

    // STEP 1: SS with 128-col TMEM
    if (wid == 0) tmem_alloc(__cvta_generic_to_shared(sTmemBase), 128);
    __syncthreads();
    uint32_t tb1 = *sTmemBase;

    uint64_t da = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sAB), 128);
    uint64_t db = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sAB + 128 * 16), 128);
    uint32_t idesc_ss = make_idesc(128, 128);
    if (tid == 0) umma_ss_f16(tb1, da, db, idesc_ss, false);
    asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    __syncthreads();

    if (wid == 0) tmem_read_print_16(tb1, 0, "Phase4 SS S[0,0..15]");

    // Save S values before dealloc (just row 0, first 16 cols).
    // tcgen05.ld / tcgen05.wait::ld are .sync.aligned, so the whole of warp 0
    // must execute them together; issuing them from lane 0 alone is undefined
    // behaviour. Only lane 0 keeps the values.
    float saved_s[16] = {0};
    if (wid == 0) {
        for (int n = 0; n < 2; n++) {
            float tmp[8];
            asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
                : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
                  "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
                : "r"(tb1 + n*8));
            asm volatile("tcgen05.wait::ld.sync.aligned;");
            if (lane == 0)
                for (int c = 0; c < 8; c++) saved_s[n*8+c] = tmp[c];
        }
    }
    // The saved copy is not consumed yet; it is kept so the read-before-dealloc
    // step stays part of the experiment.
    (void)saved_s;
    __syncthreads();

    // Dealloc SS TMEM
    if (wid == 0) tmem_dealloc(tb1, 128);
    __syncthreads();

    // STEP 2: Realloc TMEM for TS — 64 cols
    if (wid == 0) tmem_alloc(__cvta_generic_to_shared(sTmemBase), 64);
    __syncthreads();
    uint32_t tb2 = *sTmemBase;

    // Write A = all 1.0 to new TMEM cols 0-15
    if (wid == 0) {
        tmem_write_ones_128x16(tb2, 0, 1.0f);
        tmem_fence_store();
    }
    __syncthreads();

    // TS MMA
    uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sV), 16);
    uint32_t idesc_ts = make_idesc(128, 16);
    if (tid == 0) umma_ts_f16(tb2 + 32, tb2, dv, idesc_ts, false);
    asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    __syncthreads();

    if (wid == 0) tmem_read_print_16(tb2, 32, "Phase4 TS C[0,0..15]");
    __syncthreads();
    if (wid == 0) tmem_dealloc(tb2, 64);
}

// ============================================================
// PHASE 5: Two SS calls in sequence
// Does SS→SS work?
// ============================================================
// Launch: 1 block x 128 threads. Dynamic shared memory holds a 4-byte TMEM
// base slot followed by two 16-byte-aligned (128,16) bf16 tiles (A = 1.0,
// B = 2.0). One 256-column TMEM allocation receives both results:
// SS #1 at columns 0-127 and SS #2 at columns 128-255. Row 0 of the first
// 16 columns of each result is printed by warp 0.
__global__ void __launch_bounds__(128) test_phase5_ss_ss() {
    const bool in_warp0  = (threadIdx.x / 32) == 0;
    const bool is_issuer = (threadIdx.x == 0);

    extern __shared__ char sbuf[];
    uint32_t* tmem_slot = (uint32_t*)sbuf;

    // Operand tiles start at the first 16-byte boundary after the slot.
    const uintptr_t tile_addr = ((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15;
    bf16_t* tile_a = (bf16_t*)tile_addr;
    bf16_t* tile_b = tile_a + 128 * 16;

    fill_smem_128x16_ones(tile_a, 1.0f);
    fill_smem_128x16_ones(tile_b, 2.0f);

    // Warp 0 allocates TMEM; everyone picks up the base after the barrier.
    if (in_warp0) {
        tmem_alloc(__cvta_generic_to_shared(tmem_slot), 256);
    }
    __syncthreads();
    const uint32_t tmem_base = *tmem_slot;

    // The same descriptors drive both MMAs.
    const uint64_t desc_a = make_umma_desc_kmajor_none(__cvta_generic_to_shared(tile_a), 128);
    const uint64_t desc_b = make_umma_desc_kmajor_none(__cvta_generic_to_shared(tile_b), 128);
    const uint32_t instr_desc = make_idesc(128, 128);

    // SS #1 → cols 0-127
    if (is_issuer) {
        umma_ss_f16(tmem_base, desc_a, desc_b, instr_desc, false);
    }
    asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    __syncthreads();
    if (in_warp0) {
        tmem_read_print_16(tmem_base, 0, "Phase5 SS#1 S[0,0..15]");
    }

    // SS #2 → cols 128-255 (separate output region)
    const uint32_t second_out = tmem_base + 128;
    if (is_issuer) {
        umma_ss_f16(second_out, desc_a, desc_b, instr_desc, false);
    }
    asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    __syncthreads();
    if (in_warp0) {
        tmem_read_print_16(tmem_base, 128, "Phase5 SS#2 S[0,0..15]");
    }

    // Release TMEM only once every warp is past the reads.
    __syncthreads();
    if (in_warp0) {
        tmem_dealloc(tmem_base, 256);
    }
}

// ============================================================
// PHASE 6: Two TS calls in sequence
// Does TS→TS work?
// ============================================================
// Launch: 1 block x 128 threads. Dynamic SMEM: 4-byte TMEM base slot, then
// one 16-byte-aligned (16,16) bf16 tile (V filled with 2.0).
// TMEM (128 columns): A1 at cols 0-15, C1 at cols 32-47,
//                     A2 at cols 64-79, C2 at cols 96-111.
__global__ void __launch_bounds__(128) test_phase6_ts_ts() {
    const int tid = threadIdx.x, wid = tid / 32;
    extern __shared__ char sbuf[];
    uint32_t* sTmemBase = (uint32_t*)sbuf;
    bf16_t* sV = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15);

    fill_smem_16x16_ones(sV, 2.0f);

    // Warp 0 allocates; the base address is published through shared memory.
    if (wid == 0) tmem_alloc(__cvta_generic_to_shared(sTmemBase), 128);
    __syncthreads();
    uint32_t tb = *sTmemBase;

    // Write A1 at cols 0-15, A2 at cols 64-79
    if (wid == 0) {
        tmem_write_ones_128x16(tb, 0, 1.0f);
        tmem_write_ones_128x16(tb, 64, 3.0f);
        tmem_fence_store();
    }
    __syncthreads();

    // TS #1: A1 × V → C1 at cols 32-47
    uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sV), 16);
    uint32_t idesc = make_idesc(128, 16);
    if (tid == 0) umma_ts_f16(tb + 32, tb, dv, idesc, false);
    asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    __syncthreads();
    if (wid == 0) tmem_read_print_16(tb, 32, "Phase6 TS#1 C[0,0..15]");

    // TS #2: A2 × V → C2 at cols 96-111
    if (tid == 0) umma_ts_f16(tb + 96, tb + 64, dv, idesc, false);
    asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    __syncthreads();
    if (wid == 0) tmem_read_print_16(tb, 96, "Phase6 TS#2 C[0,0..15]");
    __syncthreads();
    if (wid == 0) tmem_dealloc(tb, 128);
}

// Host driver: runs phases 1-6 one after another, each as a single
// 128-thread block, and prints PASS/CRASH per phase. Always returns 0;
// results are reported on stdout only.
int main() {
    printf("=== SS + TS Sequence Test ===\n\n");

    // Launches `kernel` with `smem` bytes of dynamic shared memory and reports
    // the outcome. The device is reset after a failure so that a sticky error
    // from one phase does not poison the following phases.
    auto run = [](const char* name, void (*kernel)(), int smem) {
        printf("--- %s ---\n", name);
        kernel<<<1, 128, smem>>>();

        // Launch-configuration failures are reported by cudaGetLastError();
        // without this check a kernel that never ran would be printed as PASS.
        cudaError_t err = cudaGetLastError();
        if (err == cudaSuccess) {
            // Faults raised while the kernel executes surface here.
            err = cudaDeviceSynchronize();
        }

        if (err != cudaSuccess) {
            printf(" CRASH: %s\n\n", cudaGetErrorString(err));
            // Reset GPU state for next test
            cudaError_t reset_err = cudaDeviceReset();
            if (reset_err != cudaSuccess) {
                printf(" (device reset failed: %s)\n\n", cudaGetErrorString(reset_err));
            }
        } else {
            printf(" PASS\n\n");
        }
    };

    // Shared-memory budgets: 4-byte TMEM slot + 16 bytes of alignment slack
    // + operand tiles (2 bytes per element) + 256 bytes of extra slack,
    // rounded up to a multiple of 128.

    // Phase 1: SS alone
    int smem1 = (4+16 + 2*128*16*2 + 256 + 127) & ~127;
    run("Phase 1: SS alone", test_phase1_ss, smem1);

    // Phase 2: TS alone
    int smem2 = (4+16 + 16*16*2 + 256 + 127) & ~127;
    run("Phase 2: TS alone", test_phase2_ts, smem2);

    // Phase 3: SS then TS (the crash case)
    int smem3 = (4+16 + 2*128*16*2 + 16*16*2 + 256 + 127) & ~127;
    run("Phase 3: SS → TS (same TMEM)", test_phase3_ss_then_ts, smem3);

    // Phase 4: SS then TS with dealloc/realloc
    run("Phase 4: SS → dealloc → TS (separate TMEM)", test_phase4_ss_ts_separate_tmem, smem3);

    // Phase 5: SS → SS
    int smem5 = (4+16 + 2*128*16*2 + 256 + 127) & ~127;
    run("Phase 5: SS → SS", test_phase5_ss_ss, smem5);

    // Phase 6: TS → TS
    int smem6 = (4+16 + 16*16*2 + 256 + 127) & ~127;
    run("Phase 6: TS → TS", test_phase6_ts_ts, smem6);

    return 0;
}