440 lines
16 KiB
Plaintext
440 lines
16 KiB
Plaintext
/**
|
||
* Systematic test: Can tcgen05.mma SS and TS coexist in the same kernel?
|
||
*
|
||
* The crash from combining QK (SS) + PV (TS) in the same kernel suggests
|
||
* a TMEM state issue. This test systematically varies:
|
||
* 1. TMEM allocation size
|
||
* 2. Whether SS and TS share the same TMEM region or use separate regions
|
||
* 3. Whether there's a TMEM dealloc+realloc between SS and TS
|
||
* 4. Whether the issue is SS→TS specifically, or any second MMA call
|
||
*
|
||
* The working test_mma_ts.cu uses TS alone and works. test_fmha_hd64.cu
|
||
* uses SS alone (with register-math PV) and works. The crash only happens
|
||
* when both are in the same kernel.
|
||
*
|
||
* Key hypothesis: After tcgen05.mma SS, the TMEM C operand's internal
|
||
* state may conflict with the TS MMA's A operand read. The SS MMA writes
|
||
* to TMEM columns in Layout D format. The TS MMA reads A from TMEM and
|
||
* expects it in a specific format. If the formats don't match, or if
|
||
* there's a hardware fence missing between SS and TS, that could cause
|
||
* the crash.
|
||
*
|
||
* Alternative hypothesis: The issue is simpler — TMEM column accounting.
|
||
* SS writes 128 columns. TS reads from columns 0-15 (subset of P) and
|
||
* writes to columns 128-143 (O region). Maybe the SS MMA reserves or
|
||
* locks TMEM columns in a way that TS can't access.
|
||
*
|
||
* Test plan:
* Phase 1: SS alone (baseline — should work)
* Phase 2: TS alone (baseline — should work)
* Phase 3: SS then TS, same TMEM (the crash case)
* Phase 4: SS then TS, separate TMEM allocs (dealloc after SS, realloc for TS)
* Phase 5: Two SS calls in sequence (does SS→SS work?)
* Phase 6: Two TS calls in sequence (does TS→TS work?)
* Phase 7: SS then TS with explicit TMEM barrier
*          (planned; not implemented yet — no kernel or launch exists for it)
*/
|
||
|
||
#include <cuda_runtime.h>
|
||
#include <cstdio>
|
||
#include <cmath>
|
||
#include <cstring>
|
||
|
||
#include "dsv4/kernels/attention/fmha_common.cuh"
|
||
#include "dsv4/kernels/attention/fmha_umma_desc.cuh"
|
||
|
||
using namespace dsv4::kernels::attention;
|
||
|
||
// ============================================================
// Helper: fill a (16,16) bf16 SMEM tile, laid out as 8x8 core
// matrices, with a constant value (used for the TS B operand)
// ============================================================
/**
 * Writes f32_to_bf16(val) into every element of a 16x16 bf16 tile at `s`.
 *
 * Layout (from the index math below): element (r, c) is stored at
 *   (c/8) * CORES_MN*64 + (r/8) * 64 + (r%8) * 8 + (c%8),   CORES_MN = 16/8
 * i.e. 8x8 core matrices of 64 contiguous elements, cores along the row
 * (M/N) dimension adjacent, and the K-core stride CORES_MN*64 elements.
 * The kernels in this file describe this buffer with
 * make_umma_desc_kmajor_none(..., 16).
 *
 * The tile is first zeroed, then filled (the mapping covers all 256
 * elements, so every element ends up equal to `val`).
 *
 * Block-collective: contains __syncthreads(), so every thread of the block
 * must call it. The loops stride by blockDim.x (previously a hard-coded 128,
 * which left elements unwritten for blocks with fewer than 128 threads).
 *
 * @param s    SMEM tile base, 256 bf16 elements.
 * @param val  value written to every element (default 1.0f).
 */
__device__ void fill_smem_16x16_ones(bf16_t* s, float val = 1.0f) {
    constexpr int ROWS = 16, COLS = 16;
    constexpr int CORES_MN = ROWS / 8;   // core matrices along M/N
    constexpr int ELEMS = ROWS * COLS;
    const int step = blockDim.x;

    for (int i = threadIdx.x; i < ELEMS; i += step) s[i] = 0;
    __syncthreads();

    const auto v = f32_to_bf16(val);
    for (int i = threadIdx.x; i < ELEMS; i += step) {
        const int r = i / COLS, c = i % COLS;
        s[(c / 8) * CORES_MN * 64 + (r / 8) * 64 + (r % 8) * 8 + (c % 8)] = v;
    }
    __syncthreads();
}

// ============================================================
// Helper: fill a (128,16) bf16 SMEM tile, laid out as 8x8 core
// matrices, with a constant value (used for the SS Q/K operands)
// ============================================================
/**
 * Writes f32_to_bf16(val) into every element of a 128x16 bf16 tile at `s`.
 *
 * Layout (from the index math below): element (r, c) is stored at
 *   (c/8) * CORES_MN*64 + (r/8) * 64 + (r%8) * 8 + (c%8),   CORES_MN = 128/8
 * i.e. 8x8 core matrices of 64 contiguous elements, the 16 row-direction
 * cores adjacent, and the K-core stride CORES_MN*64 elements. The kernels in
 * this file describe this buffer with make_umma_desc_kmajor_none(..., 128).
 *
 * The tile is first zeroed, then filled (the mapping covers all 2048
 * elements, so every element ends up equal to `val`).
 *
 * Block-collective: contains __syncthreads(), so every thread of the block
 * must call it. The loops stride by blockDim.x (previously a hard-coded 128,
 * which left elements unwritten for blocks with fewer than 128 threads).
 *
 * @param s    SMEM tile base, 2048 bf16 elements.
 * @param val  value written to every element (default 1.0f).
 */
__device__ void fill_smem_128x16_ones(bf16_t* s, float val = 1.0f) {
    constexpr int ROWS = 128, COLS = 16;
    constexpr int CORES_MN = ROWS / 8;   // 16 core matrices along M/N
    constexpr int ELEMS = ROWS * COLS;
    const int step = blockDim.x;

    for (int i = threadIdx.x; i < ELEMS; i += step) s[i] = 0;
    __syncthreads();

    const auto v = f32_to_bf16(val);
    for (int i = threadIdx.x; i < ELEMS; i += step) {
        const int r = i / COLS, c = i % COLS;
        s[(c / 8) * CORES_MN * 64 + (r / 8) * 64 + (r % 8) * 8 + (c % 8)] = v;
    }
    __syncthreads();
}

// ============================================================
// Helper: store `val` into TMEM cols [col_start, col_start+16)
// with two warp-collective tcgen05.st 32x32b.x8 stores
// ============================================================
/**
 * Stores the 32-bit float `val` into 16 consecutive TMEM columns starting
 * at tb + col_start, for the 32 TMEM lanes addressed by the calling warp.
 *
 * Warp-collective (.sync.aligned): all 32 lanes must execute it.
 * The stores are not waited on here; the callers in this file follow it
 * with tmem_fence_store().
 *
 * NOTE(review): `val` is stored as an fp32 bit pattern, one value per
 * column. If the TS MMA reads its TMEM A operand as packed 16-bit elements,
 * this does not match that format — confirm against umma_ts_f16/make_idesc.
 *
 * @param tb         TMEM base address from the allocation.
 * @param col_start  first column (relative to tb) to write.
 * @param val        value stored in every column (default 1.0f).
 */
__device__ void tmem_write_ones_128x16(uint32_t tb, int col_start, float val = 1.0f) {
    const uint32_t base = tb + col_start;
#pragma unroll
    for (int chunk = 0; chunk < 16; chunk += 8) {
        asm volatile("tcgen05.st.sync.aligned.32x32b.x8.b32 [%0],{%1,%2,%3,%4,%5,%6,%7,%8};"
                     :: "r"(base + chunk),
                        "f"(val), "f"(val), "f"(val), "f"(val),
                        "f"(val), "f"(val), "f"(val), "f"(val));
    }
}

// ============================================================
// Helper: load 16 TMEM columns starting at col_start and print
// lane 0's values (row 0 when called from warp 0)
// ============================================================
/**
 * Loads 16 consecutive fp32 TMEM columns (tb + col_start ...) with two
 * tcgen05.ld 32x32b.x8 loads. Each load is waited on with tcgen05.wait::ld.
 * Lane 0 of the calling warp then prints its 16 values after `label`.
 *
 * Warp-collective (.sync.aligned): all 32 lanes must execute it. Every
 * caller in this file uses warp 0, so the printed lane is presumably TMEM
 * lane 0, i.e. row 0 of the tile.
 *
 * @param tb         TMEM base address from the allocation.
 * @param col_start  first column (relative to tb) to read.
 * @param label      printf prefix for the output line.
 */
__device__ void tmem_read_print_16(uint32_t tb, int col_start, const char* label) {
    const bool is_leader = (threadIdx.x % 32) == 0;
    float row[16];

    for (int chunk = 0; chunk < 2; ++chunk) {
        float r[8];
        asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
                     : "=f"(r[0]), "=f"(r[1]), "=f"(r[2]), "=f"(r[3]),
                       "=f"(r[4]), "=f"(r[5]), "=f"(r[6]), "=f"(r[7])
                     : "r"(tb + col_start + chunk * 8));
        asm volatile("tcgen05.wait::ld.sync.aligned;");
        if (is_leader) {
            for (int c = 0; c < 8; ++c) row[chunk * 8 + c] = r[c];
        }
    }

    if (!is_leader) return;
    printf("%s: ", label);
    for (int c = 0; c < 16; ++c) printf("%.1f ", row[c]);
    printf("\n");
}

// ============================================================
// PHASE 1: SS alone
// ============================================================
/**
 * Baseline: one SS MMA. A(128,16) = 1.0 and B(128,16) = 2.0 both live in
 * SMEM; D(128,128) is written to TMEM cols [tb, tb+128) and row 0,
 * cols 0-15 of D are printed.
 *
 * Launch: 1 block x 128 threads. The dynamic SMEM must hold the 4-byte TMEM
 * base slot plus two 128x16 bf16 tiles (main() passes smem1).
 *
 * tcgen05 ordering: every cross-thread hand-off (alloc -> use, mma -> ld,
 * ld -> dealloc) goes through tcgen05.fence::before_thread_sync,
 * __syncthreads(), tcgen05.fence::after_thread_sync. The previous code only
 * issued fence::after_thread_sync, and issued it before the barrier, so the
 * tcgen05 ops following the barrier were not ordered after it.
 *
 * NOTE(review): tcgen05.mma is asynchronous. Unless umma_ss_f16 commits to an
 * mbarrier and waits internally, these barriers order the MMA issue but do
 * not guarantee the accumulator is complete before tcgen05.ld — confirm
 * against the helper's definition.
 */
__global__ void __launch_bounds__(128)
test_phase1_ss()
{
    const int tid = threadIdx.x, wid = tid / 32;

    // Block barrier bracketed by the tcgen05 thread-sync fences.
    auto tc_barrier = [] {
        asm volatile("tcgen05.fence::before_thread_sync;" ::: "memory");
        __syncthreads();
        asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    };

    extern __shared__ char sbuf[];
    uint32_t* sTmemBase = (uint32_t*)sbuf;
    bf16_t* sA = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15);
    bf16_t* sB = sA + 128 * 16;  // second (128,16) tile, directly after sA

    fill_smem_128x16_ones(sA, 1.0f);
    fill_smem_128x16_ones(sB, 2.0f);

    if (wid == 0) tmem_alloc(__cvta_generic_to_shared(sTmemBase), 128);
    tc_barrier();  // make the TMEM base address visible to all threads
    const uint32_t tb = *sTmemBase;

    // SS: A(128,16) x B(128,16) -> C(128,128) at tb
    const uint64_t da = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sA), 128);
    const uint64_t db = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sB), 128);
    const uint32_t idesc = make_idesc(128, 128);
    if (tid == 0) umma_ss_f16(tb, da, db, idesc, false);
    tc_barrier();  // order the MMA issue before the TMEM reads

    // Read first 16 cols of S
    if (wid == 0) tmem_read_print_16(tb, 0, "Phase1 SS S[0,0..15]");
    tc_barrier();  // order all TMEM reads before the dealloc

    if (wid == 0) tmem_dealloc(tb, 128);
}

// ============================================================
// PHASE 2: TS alone
// ============================================================
/**
 * Baseline: one TS MMA. A(128,16) = 1.0 is written into TMEM cols 0-15 with
 * tcgen05.st, and B(16,16) = 2.0 lives in SMEM. C(128,16) is written to TMEM
 * cols 32-47, and row 0 of C is printed.
 *
 * Launch: 1 block x 128 threads. The dynamic SMEM must hold the 4-byte TMEM
 * base slot plus one 16x16 bf16 tile (main() passes smem2).
 *
 * tcgen05 ordering: every cross-thread hand-off (alloc -> use, st -> mma,
 * mma -> ld, ld -> dealloc) goes through tcgen05.fence::before_thread_sync,
 * __syncthreads(), tcgen05.fence::after_thread_sync. The previous code only
 * issued fence::after_thread_sync, and issued it before the barrier, so the
 * tcgen05 ops following the barrier were not ordered after it.
 *
 * NOTE(review): tcgen05.mma is asynchronous. Unless umma_ts_f16 commits to an
 * mbarrier and waits internally, these barriers order the MMA issue but do
 * not guarantee C is complete before tcgen05.ld — confirm.
 */
__global__ void __launch_bounds__(128)
test_phase2_ts()
{
    const int tid = threadIdx.x, wid = tid / 32;

    // Block barrier bracketed by the tcgen05 thread-sync fences.
    auto tc_barrier = [] {
        asm volatile("tcgen05.fence::before_thread_sync;" ::: "memory");
        __syncthreads();
        asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    };

    extern __shared__ char sbuf[];
    uint32_t* sTmemBase = (uint32_t*)sbuf;
    bf16_t* sV = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15);

    fill_smem_16x16_ones(sV, 2.0f);

    if (wid == 0) tmem_alloc(__cvta_generic_to_shared(sTmemBase), 64);
    tc_barrier();  // make the TMEM base address visible to all threads
    const uint32_t tb = *sTmemBase;

    // Write A = all 1.0 to cols 0-15
    if (wid == 0) { tmem_write_ones_128x16(tb, 0, 1.0f); tmem_fence_store(); }
    tc_barrier();  // order the TMEM stores before the MMA issue

    // TS: A(128,16, TMEM) x B(16,16, SMEM) -> C(128,16, TMEM) at tb+32
    const uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sV), 16);
    const uint32_t idesc = make_idesc(128, 16);
    if (tid == 0) umma_ts_f16(tb + 32, tb, dv, idesc, false);
    tc_barrier();  // order the MMA issue before the TMEM reads

    // Read C
    if (wid == 0) tmem_read_print_16(tb, 32, "Phase2 TS C[0,0..15]");
    tc_barrier();  // order all TMEM reads before the dealloc

    if (wid == 0) tmem_dealloc(tb, 64);
}

// ============================================================
// PHASE 3: SS then TS, same TMEM allocation
// This is the crash case we need to debug.
// ============================================================
/**
 * SS MMA writes S(128,128) to TMEM cols 0-127. A TS MMA then uses the first
 * 16 columns of S as its TMEM A operand, B(16,16) = 2.0 from SMEM, and
 * writes O(128,16) to cols 128-143. The allocation is 256 columns.
 *
 * Launch: 1 block x 128 threads. The dynamic SMEM holds the 4-byte TMEM base
 * slot, two 128x16 tiles, and a 128-byte-aligned 16x16 tile (main() passes
 * smem3).
 *
 * tcgen05 ordering: every cross-thread hand-off goes through
 * tcgen05.fence::before_thread_sync, __syncthreads(),
 * tcgen05.fence::after_thread_sync. The previous code only issued
 * fence::after_thread_sync, and issued it before the barrier.
 *
 * NOTE(review): tcgen05.mma is asynchronous. Unless umma_*_f16 commits to an
 * mbarrier and waits internally, nothing here guarantees S is complete
 * before it is read (by tcgen05.ld or as the TS A operand) — confirm.
 * NOTE(review): S is the SS accumulator (presumably fp32, one value per
 * column). If the TS MMA reads A from TMEM as packed 16-bit elements, S is
 * being reinterpreted here (the "Layout D" hypothesis in the file header)
 * — confirm against umma_ts_f16/make_idesc.
 */
__global__ void __launch_bounds__(128)
test_phase3_ss_then_ts()
{
    const int tid = threadIdx.x, wid = tid / 32;

    // Block barrier bracketed by the tcgen05 thread-sync fences.
    auto tc_barrier = [] {
        asm volatile("tcgen05.fence::before_thread_sync;" ::: "memory");
        __syncthreads();
        asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    };

    extern __shared__ char sbuf[];
    uint32_t* sTmemBase = (uint32_t*)sbuf;
    bf16_t* sAB = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15);
    bf16_t* sV = (bf16_t*)(((uintptr_t)(sAB + 2 * 128 * 16) + 127) & ~(uintptr_t)127);

    // SS buffers
    fill_smem_128x16_ones(sAB, 1.0f);
    fill_smem_128x16_ones(sAB + 128 * 16, 2.0f);
    // TS buffer
    fill_smem_16x16_ones(sV, 2.0f);

    // TMEM: 256 columns. 0-127 = SS output (S). 128-143 = TS output (O).
    if (wid == 0) tmem_alloc(__cvta_generic_to_shared(sTmemBase), 256);
    tc_barrier();  // make the TMEM base address visible to all threads
    const uint32_t tb = *sTmemBase;

    // STEP 1: SS QK GEMM -> S at tb (cols 0-127)
    const uint64_t da = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sAB), 128);
    const uint64_t db = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sAB + 128 * 16), 128);
    const uint32_t idesc_ss = make_idesc(128, 128);
    if (tid == 0) umma_ss_f16(tb, da, db, idesc_ss, false);
    tc_barrier();  // order the SS issue before the TMEM reads

    // Verify SS output. As before, there is no barrier between this read and
    // the TS issue: thread 0 belongs to warp 0, and tmem_read_print_16 waits
    // for its loads (tcgen05.wait::ld) before returning.
    if (wid == 0) tmem_read_print_16(tb, 0, "Phase3 SS S[0,0..15]");

    // STEP 2: TS PV GEMM -> O at tb+128 (cols 128-143)
    //   A = first 16 cols of S (tb + 0)
    //   B = sV (16, 16)
    //   C = O at tb + 128
    const uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sV), 16);
    const uint32_t idesc_ts = make_idesc(128, 16);
    if (tid == 0) umma_ts_f16(tb + 128, tb, dv, idesc_ts, false);
    tc_barrier();  // order the TS issue before the TMEM reads

    // Read TS output
    if (wid == 0) tmem_read_print_16(tb, 128, "Phase3 TS O[0,0..15]");
    tc_barrier();  // order all TMEM reads before the dealloc

    if (wid == 0) tmem_dealloc(tb, 256);
}

// ============================================================
// PHASE 4: SS then TS, with TMEM dealloc + realloc between them
// Tests whether the crash is due to TMEM state from SS interfering with TS
// ============================================================
/**
 * Runs the SS MMA from Phase 1 in a 128-column allocation and prints row 0
 * of S. It then frees that allocation, allocates 64 fresh columns, and runs
 * the TS MMA from Phase 2 there. The TS A operand is the 1.0 values written
 * with tcgen05.st, not S. The result goes to cols 32-47.
 *
 * Launch: 1 block x 128 threads. main() passes the same dynamic SMEM size as
 * Phase 3 (smem3), and the SMEM layout is the same.
 *
 * Fix: the previous version reloaded S into a local `saved_s` with
 * tcgen05.ld.sync.aligned inside `if (wid == 0 && lane == 0)`. `.sync.aligned`
 * requires every lane of the warp to execute the instruction, so a
 * single-lane execution is undefined behaviour. It may fault by itself and
 * confound this experiment. `saved_s` was never used afterwards, and row 0
 * of S is already printed by tmem_read_print_16, so that load is removed.
 *
 * tcgen05 ordering: every cross-thread hand-off (alloc -> use, st -> mma,
 * mma -> ld, ld -> dealloc) goes through tcgen05.fence::before_thread_sync,
 * __syncthreads(), tcgen05.fence::after_thread_sync. Previously only
 * fence::after_thread_sync was issued, and it was issued before the barrier.
 *
 * NOTE(review): tcgen05.mma is asynchronous. Unless umma_*_f16 commits to an
 * mbarrier and waits internally, these barriers do not guarantee MMA
 * completion before tcgen05.ld — confirm.
 */
__global__ void __launch_bounds__(128)
test_phase4_ss_ts_separate_tmem()
{
    const int tid = threadIdx.x, wid = tid / 32;

    // Block barrier bracketed by the tcgen05 thread-sync fences.
    auto tc_barrier = [] {
        asm volatile("tcgen05.fence::before_thread_sync;" ::: "memory");
        __syncthreads();
        asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    };

    extern __shared__ char sbuf[];
    uint32_t* sTmemBase = (uint32_t*)sbuf;
    bf16_t* sAB = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15);
    bf16_t* sV = (bf16_t*)(((uintptr_t)(sAB + 2 * 128 * 16) + 127) & ~(uintptr_t)127);

    fill_smem_128x16_ones(sAB, 1.0f);
    fill_smem_128x16_ones(sAB + 128 * 16, 2.0f);
    fill_smem_16x16_ones(sV, 2.0f);

    // STEP 1: SS with 128-col TMEM
    if (wid == 0) tmem_alloc(__cvta_generic_to_shared(sTmemBase), 128);
    tc_barrier();  // make the TMEM base address visible to all threads
    const uint32_t tb1 = *sTmemBase;

    const uint64_t da = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sAB), 128);
    const uint64_t db = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sAB + 128 * 16), 128);
    const uint32_t idesc_ss = make_idesc(128, 128);
    if (tid == 0) umma_ss_f16(tb1, da, db, idesc_ss, false);
    tc_barrier();  // order the SS issue before the TMEM reads

    // Row 0 of S is captured here; no separate copy is needed before dealloc.
    if (wid == 0) tmem_read_print_16(tb1, 0, "Phase4 SS S[0,0..15]");
    tc_barrier();  // order all reads of tb1 before the dealloc

    // Dealloc SS TMEM
    if (wid == 0) tmem_dealloc(tb1, 128);
    tc_barrier();

    // STEP 2: Realloc TMEM for TS — 64 cols
    if (wid == 0) tmem_alloc(__cvta_generic_to_shared(sTmemBase), 64);
    tc_barrier();  // make the new TMEM base address visible to all threads
    const uint32_t tb2 = *sTmemBase;

    // Write A = all 1.0 to new TMEM cols 0-15
    if (wid == 0) { tmem_write_ones_128x16(tb2, 0, 1.0f); tmem_fence_store(); }
    tc_barrier();  // order the TMEM stores before the MMA issue

    // TS MMA
    const uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sV), 16);
    const uint32_t idesc_ts = make_idesc(128, 16);
    if (tid == 0) umma_ts_f16(tb2 + 32, tb2, dv, idesc_ts, false);
    tc_barrier();  // order the TS issue before the TMEM reads

    if (wid == 0) tmem_read_print_16(tb2, 32, "Phase4 TS C[0,0..15]");
    tc_barrier();  // order all TMEM reads before the dealloc

    if (wid == 0) tmem_dealloc(tb2, 64);
}

// ============================================================
// PHASE 5: Two SS calls in sequence
// Does SS→SS work?
// ============================================================
/**
 * Two identical SS MMAs (same A/B descriptors) write disjoint 128-column
 * regions of one 256-column TMEM allocation: cols 0-127, then cols 128-255.
 * Row 0, cols 0-15 of each result are printed.
 *
 * Launch: 1 block x 128 threads. The dynamic SMEM must hold the 4-byte TMEM
 * base slot plus two 128x16 bf16 tiles (main() passes smem5).
 *
 * tcgen05 ordering: every cross-thread hand-off goes through
 * tcgen05.fence::before_thread_sync, __syncthreads(),
 * tcgen05.fence::after_thread_sync. The previous code only issued
 * fence::after_thread_sync, and issued it before the barrier.
 *
 * NOTE(review): tcgen05.mma is asynchronous. Unless umma_ss_f16 commits to an
 * mbarrier and waits internally, these barriers do not guarantee MMA
 * completion before tcgen05.ld — confirm.
 */
__global__ void __launch_bounds__(128)
test_phase5_ss_ss()
{
    const int tid = threadIdx.x, wid = tid / 32;

    // Block barrier bracketed by the tcgen05 thread-sync fences.
    auto tc_barrier = [] {
        asm volatile("tcgen05.fence::before_thread_sync;" ::: "memory");
        __syncthreads();
        asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    };

    extern __shared__ char sbuf[];
    uint32_t* sTmemBase = (uint32_t*)sbuf;
    bf16_t* sA1 = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15);
    bf16_t* sB1 = sA1 + 128 * 16;

    fill_smem_128x16_ones(sA1, 1.0f);
    fill_smem_128x16_ones(sB1, 2.0f);

    if (wid == 0) tmem_alloc(__cvta_generic_to_shared(sTmemBase), 256);
    tc_barrier();  // make the TMEM base address visible to all threads
    const uint32_t tb = *sTmemBase;

    // SS #1 -> cols 0-127
    const uint64_t da1 = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sA1), 128);
    const uint64_t db1 = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sB1), 128);
    const uint32_t idesc1 = make_idesc(128, 128);
    if (tid == 0) umma_ss_f16(tb, da1, db1, idesc1, false);
    tc_barrier();  // order SS #1 issue before the TMEM reads

    if (wid == 0) tmem_read_print_16(tb, 0, "Phase5 SS#1 S[0,0..15]");

    // SS #2 -> cols 128-255 (separate output region)
    if (tid == 0) umma_ss_f16(tb + 128, da1, db1, idesc1, false);
    tc_barrier();  // order SS #2 issue before the TMEM reads

    if (wid == 0) tmem_read_print_16(tb, 128, "Phase5 SS#2 S[0,0..15]");
    tc_barrier();  // order all TMEM reads before the dealloc

    if (wid == 0) tmem_dealloc(tb, 256);
}

// ============================================================
// PHASE 6: Two TS calls in sequence
// Does TS→TS work?
// ============================================================
/**
 * Two TS MMAs share one 128-column TMEM allocation and the same B(16,16)=2.0
 * SMEM tile:
 *   TS #1: A1 = 1.0 in cols 0-15   -> C1 in cols 32-47
 *   TS #2: A2 = 3.0 in cols 64-79  -> C2 in cols 96-111
 * Row 0, cols 0-15 of each result are printed.
 *
 * Launch: 1 block x 128 threads. The dynamic SMEM must hold the 4-byte TMEM
 * base slot plus one 16x16 bf16 tile (main() passes smem6).
 *
 * tcgen05 ordering: every cross-thread hand-off goes through
 * tcgen05.fence::before_thread_sync, __syncthreads(),
 * tcgen05.fence::after_thread_sync. The previous code only issued
 * fence::after_thread_sync, and issued it before the barrier.
 *
 * NOTE(review): tcgen05.mma is asynchronous. Unless umma_ts_f16 commits to an
 * mbarrier and waits internally, these barriers do not guarantee MMA
 * completion before tcgen05.ld — confirm.
 */
__global__ void __launch_bounds__(128)
test_phase6_ts_ts()
{
    const int tid = threadIdx.x, wid = tid / 32;

    // Block barrier bracketed by the tcgen05 thread-sync fences.
    auto tc_barrier = [] {
        asm volatile("tcgen05.fence::before_thread_sync;" ::: "memory");
        __syncthreads();
        asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    };

    extern __shared__ char sbuf[];
    uint32_t* sTmemBase = (uint32_t*)sbuf;
    bf16_t* sV = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15);

    fill_smem_16x16_ones(sV, 2.0f);

    if (wid == 0) tmem_alloc(__cvta_generic_to_shared(sTmemBase), 128);
    tc_barrier();  // make the TMEM base address visible to all threads
    const uint32_t tb = *sTmemBase;

    // Write A1 at cols 0-15, A2 at cols 64-79
    if (wid == 0) {
        tmem_write_ones_128x16(tb, 0, 1.0f);
        tmem_write_ones_128x16(tb, 64, 3.0f);
        tmem_fence_store();
    }
    tc_barrier();  // order the TMEM stores before the MMA issues

    // TS #1: A1 x V -> C1 at cols 32-47
    const uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sV), 16);
    const uint32_t idesc = make_idesc(128, 16);
    if (tid == 0) umma_ts_f16(tb + 32, tb, dv, idesc, false);
    tc_barrier();  // order TS #1 issue before the TMEM reads

    if (wid == 0) tmem_read_print_16(tb, 32, "Phase6 TS#1 C[0,0..15]");

    // TS #2: A2 x V -> C2 at cols 96-111
    if (tid == 0) umma_ts_f16(tb + 96, tb + 64, dv, idesc, false);
    tc_barrier();  // order TS #2 issue before the TMEM reads

    if (wid == 0) tmem_read_print_16(tb, 96, "Phase6 TS#2 C[0,0..15]");
    tc_barrier();  // order all TMEM reads before the dealloc

    if (wid == 0) tmem_dealloc(tb, 128);
}

/**
 * Runs each phase kernel as a single 128-thread block and reports PASS
 * (the kernel ran to completion without an error; values are NOT checked)
 * or CRASH with the CUDA error string.
 *
 * Improvements over the previous version:
 *  - Launch-configuration errors (e.g. too much dynamic SMEM) are caught via
 *    cudaGetLastError() before synchronizing.
 *  - The result of cudaDeviceReset() is checked.
 *  - Failures are counted. A summary is printed, and the exit code is
 *    non-zero if any phase failed (previously it was always 0).
 *
 * Phase 7 from the plan in the file header (SS then TS with an explicit TMEM
 * barrier) has no kernel yet and is not run.
 */
int main() {
    printf("=== SS + TS Sequence Test ===\n\n");

    int failures = 0;
    int phases_run = 0;

    auto run = [&](const char* name, void (*kernel)(), int smem) {
        ++phases_run;
        printf("--- %s ---\n", name);
        kernel<<<1, 128, smem>>>();

        // Launch errors are reported here; faults during execution surface
        // at the synchronize.
        cudaError_t err = cudaGetLastError();
        if (err == cudaSuccess) err = cudaDeviceSynchronize();

        if (err != cudaSuccess) {
            printf(" CRASH: %s\n\n", cudaGetErrorString(err));
            ++failures;
            // Reset GPU state for next test
            const cudaError_t reset_err = cudaDeviceReset();
            if (reset_err != cudaSuccess) {
                printf(" cudaDeviceReset failed: %s\n\n", cudaGetErrorString(reset_err));
            }
        } else {
            printf(" PASS\n\n");
        }
    };

    // Phase 1: SS alone
    int smem1 = (4+16 + 2*128*16*2 + 256 + 127) & ~127;
    run("Phase 1: SS alone", test_phase1_ss, smem1);

    // Phase 2: TS alone
    int smem2 = (4+16 + 16*16*2 + 256 + 127) & ~127;
    run("Phase 2: TS alone", test_phase2_ts, smem2);

    // Phase 3: SS then TS (the crash case)
    int smem3 = (4+16 + 2*128*16*2 + 16*16*2 + 256 + 127) & ~127;
    run("Phase 3: SS → TS (same TMEM)", test_phase3_ss_then_ts, smem3);

    // Phase 4: SS then TS with dealloc/realloc (same SMEM layout as Phase 3)
    run("Phase 4: SS → dealloc → TS (separate TMEM)", test_phase4_ss_ts_separate_tmem, smem3);

    // Phase 5: SS → SS
    int smem5 = (4+16 + 2*128*16*2 + 256 + 127) & ~127;
    run("Phase 5: SS → SS", test_phase5_ss_ss, smem5);

    // Phase 6: TS → TS
    int smem6 = (4+16 + 16*16*2 + 256 + 127) & ~127;
    run("Phase 6: TS → TS", test_phase6_ts_ts, smem6);

    printf("=== %d of %d phases crashed ===\n", failures, phases_run);
    return failures == 0 ? 0 : 1;
}