Files
nvfp4-megamoe-kernel/tests/unit/test_ss_ts_sequence.cu

440 lines
16 KiB
Plaintext
Raw Normal View History

/**
* Systematic test: Can tcgen05.mma SS and TS coexist in the same kernel?
*
* The crash from combining QK (SS) + PV (TS) in the same kernel suggests
* a TMEM state issue. This test systematically varies:
* 1. TMEM allocation size
* 2. Whether SS and TS share the same TMEM region or use separate regions
* 3. Whether there's a TMEM dealloc+realloc between SS and TS
* 4. Whether the issue is SS→TS specifically, or any second MMA call
*
* The working test_mma_ts.cu uses TS alone and works. test_fmha_hd64.cu
* uses SS alone (with register-math PV) and works. The crash only happens
* when both are in the same kernel.
*
* Key hypothesis: After tcgen05.mma SS, the TMEM C operand's internal
* state may conflict with the TS MMA's A operand read. The SS MMA writes
* to TMEM columns in Layout D format. The TS MMA reads A from TMEM and
* expects it in a specific format. If the formats don't match, or if
* there's a hardware fence missing between SS and TS, that could cause
* the crash.
*
* Alternative hypothesis: The issue is simpler — TMEM column accounting.
* SS writes 128 columns. TS reads from columns 0-15 (subset of P) and
* writes to columns 128-143 (O region). Maybe the SS MMA reserves or
* locks TMEM columns in a way that TS can't access.
*
* Test plan:
* Phase 1: SS alone (baseline — should work)
* Phase 2: TS alone (baseline — should work)
* Phase 3: SS then TS, same TMEM (the crash case)
* Phase 4: SS then TS, separate TMEM allocs (dealloc after SS, realloc for TS)
* Phase 5: Two SS calls in sequence (does SS→SS work?)
* Phase 6: Two TS calls in sequence (does TS→TS work?)
* Phase 7: SS then TS with explicit TMEM barrier (planned; not yet implemented — main() runs Phases 1-6)
*/
#include <cuda_runtime.h>
#include <cstdio>
#include <cmath>
#include <cstring>
#include "dsv4/kernels/attention/fmha_common.cuh"
#include "dsv4/kernels/attention/fmha_umma_desc.cuh"
using namespace dsv4::kernels::attention;
// ============================================================
// Helper: fill SMEM (16,16) canonical with all-1s
// ============================================================
/**
 * Fill a 16x16 bf16 SMEM tile with `val` (default 1.0) in the canonical
 * K-major, no-swizzle layout made of 8x8 core matrices. Element (r, c) is
 * stored at
 *   (c / 8) * (CORES_MN * 64) + (r / 8) * 64 + (r % 8) * 8 + (c % 8)
 * with CORES_MN = 16 / 8 = 2, i.e. the core matrices of one 8-column K chunk
 * are stored back to back along M/N, followed by the next K chunk.
 *
 * Block-cooperative: every thread of the block must call it (it contains
 * __syncthreads()). Loops stride by blockDim.x, so any block size works.
 * The tile is zeroed first, then filled.
 *
 * Before the final barrier every thread issues fence.proxy.async.shared::cta,
 * so its generic-proxy stores are visible to async-proxy readers such as
 * tcgen05.mma once the barrier completes (sm_90+; this file targets tcgen05).
 */
__device__ void fill_smem_16x16_ones(bf16_t* s, float val = 1.0f) {
    constexpr int ROWS = 16, COLS = 16;
    constexpr int CORES_MN = ROWS / 8;
    const int stride = blockDim.x;
    for (int i = threadIdx.x; i < ROWS * COLS; i += stride) s[i] = f32_to_bf16(0.0f);
    __syncthreads();
    const bf16_t v = f32_to_bf16(val);
    for (int i = threadIdx.x; i < ROWS * COLS; i += stride) {
        const int r = i / COLS, c = i % COLS;
        const int ck = c / 8, lc = c % 8;   // K-direction core index / column inside core
        const int tmn = r / 8, lr = r % 8;  // M/N-direction core index / row inside core
        s[ck * CORES_MN * 64 + tmn * 64 + lr * 8 + lc] = v;
    }
    // Generic-proxy SMEM writes -> async proxy (tcgen05.mma operand reads).
    asm volatile("fence.proxy.async.shared::cta;" ::: "memory");
    __syncthreads();
}
// ============================================================
// Helper: fill SMEM (128,16) canonical with all-1s (for Q/K)
// ============================================================
/**
 * Fill a 128x16 bf16 SMEM tile (used for the Q/K-style SS operands) with
 * `val` (default 1.0) in the canonical K-major, no-swizzle layout made of 8x8
 * core matrices. Element (r, c) is stored at
 *   (c / 8) * (CORES_MN * 64) + (r / 8) * 64 + (r % 8) * 8 + (c % 8)
 * with CORES_MN = 128 / 8 = 16, i.e. all 16 core matrices of one 8-column K
 * chunk are stored back to back along M/N, followed by the next K chunk.
 *
 * Block-cooperative: every thread of the block must call it (it contains
 * __syncthreads()). Loops stride by blockDim.x, so any block size works.
 * The tile is zeroed first, then filled.
 *
 * Before the final barrier every thread issues fence.proxy.async.shared::cta,
 * so its generic-proxy stores are visible to async-proxy readers such as
 * tcgen05.mma once the barrier completes (sm_90+; this file targets tcgen05).
 */
__device__ void fill_smem_128x16_ones(bf16_t* s, float val = 1.0f) {
    constexpr int ROWS = 128, COLS = 16;
    constexpr int CORES_MN = ROWS / 8;
    const int stride = blockDim.x;
    for (int i = threadIdx.x; i < ROWS * COLS; i += stride) s[i] = f32_to_bf16(0.0f);
    __syncthreads();
    const bf16_t v = f32_to_bf16(val);
    for (int i = threadIdx.x; i < ROWS * COLS; i += stride) {
        const int r = i / COLS, c = i % COLS;
        const int ck = c / 8, lc = c % 8;   // K-direction core index / column inside core
        const int tmn = r / 8, lr = r % 8;  // M/N-direction core index / row inside core
        s[ck * CORES_MN * 64 + tmn * 64 + lr * 8 + lc] = v;
    }
    // Generic-proxy SMEM writes -> async proxy (tcgen05.mma operand reads).
    asm volatile("fence.proxy.async.shared::cta;" ::: "memory");
    __syncthreads();
}
// ============================================================
// Helper: write A=all-1s to TMEM cols [col_start, col_start+16)
// using 32x32b.x8 stores (warp-collective, all lanes write 1.0)
// ============================================================
/**
 * Warp-collective: store `val` into the 16 consecutive 32-bit TMEM columns
 * [col_start, col_start + 16) of the 32 TMEM lanes belonging to the calling
 * warp, using two tcgen05.st 32x32b.x8 stores. All 32 lanes must execute it
 * (.sync.aligned).
 *
 * Each 32-bit cell is written as two packed copies of bf16(val) (RNE). The
 * A operand of umma_ts_f16 presumably is a kind::f16 MMA whose TMEM A tile
 * holds 16-bit elements (TODO confirm against fmha_common.cuh). With both
 * halves equal, every bf16 element in these columns reads as `val`, however
 * many columns the MMA consumes for K = 16. Storing the raw fp32 bit pattern
 * instead would read as the bf16 pair {0, val} for values such as 1, 2, 3.
 *
 * Lane addressing: with the 32x32b shape, warp w of a warpgroup may only
 * access TMEM lanes [32*(w%4), 32*(w%4)+32), so that base is added to the
 * lane field (bits 31..16) of the address. For warp 0, which is the only
 * caller in this file, the base is 0. Only rows 0-31 of the 128-row A tile
 * are therefore written; row 0 of the MMA result, which is what the tests
 * print, depends only on row 0 of A.
 *
 * Callers still need to wait for the stores (the tests call
 * tmem_fence_store()) before an MMA reads these columns.
 */
__device__ void tmem_write_ones_128x16(uint32_t tb, int col_start, float val = 1.0f) {
    // fp32 -> bf16 bits, round-to-nearest-even (finite inputs).
    const uint32_t f32_bits = __float_as_uint(val);
    const uint32_t bf16_bits = (f32_bits + 0x7FFFu + ((f32_bits >> 16) & 1u)) >> 16;
    const uint32_t packed = (bf16_bits << 16) | bf16_bits;

    const uint32_t lane_base = ((threadIdx.x / 32u) % 4u) * 32u;
    const uint32_t addr = tb + (lane_base << 16) + (uint32_t)col_start;
#pragma unroll
    for (int chunk = 0; chunk < 2; ++chunk) {  // 16 columns = 2 x8 stores
        asm volatile("tcgen05.st.sync.aligned.32x32b.x8.b32 [%0],{%1,%2,%3,%4,%5,%6,%7,%8};"
                     :: "r"(addr + chunk * 8),
                        "r"(packed), "r"(packed), "r"(packed), "r"(packed),
                        "r"(packed), "r"(packed), "r"(packed), "r"(packed)
                     : "memory");
    }
}
// ============================================================
// Helper: read 16 TMEM cols starting at col_start, print row 0
// ============================================================
/**
 * Warp-collective: load 16 consecutive 32-bit TMEM columns starting at
 * `col_start` (relative to TMEM address `tb`) with two tcgen05.ld 32x32b.x8
 * loads, then have lane 0 print them as floats after `label`.
 *
 * With the 32x32b shape each lane reads the TMEM lane (lane field of tb) +
 * laneid, so lane 0 prints the row selected by tb's lane field (row 0 for the
 * allocation base addresses used in this file). All 32 lanes must execute the
 * call (.sync.aligned); in this file only warp 0 calls it, and warp 0 may
 * access TMEM lanes 0-31.
 */
__device__ void tmem_read_print_16(uint32_t tb, int col_start, const char* label) {
    const int lane = threadIdx.x % 32;
    float row[16];
#pragma unroll
    for (int chunk = 0; chunk < 2; ++chunk) {
        float* d = row + chunk * 8;
        // Load and wait::ld live in one asm statement so the compiler cannot
        // read the destination registers before the load has completed.
        asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];\n\t"
                     "tcgen05.wait::ld.sync.aligned;"
                     : "=f"(d[0]), "=f"(d[1]), "=f"(d[2]), "=f"(d[3]),
                       "=f"(d[4]), "=f"(d[5]), "=f"(d[6]), "=f"(d[7])
                     : "r"(tb + col_start + chunk * 8)
                     : "memory");
    }
    if (lane == 0) {
        printf("%s: ", label);
        for (int c = 0; c < 16; c++) printf("%.1f ", row[c]);
        printf("\n");
    }
}
// ============================================================
// PHASE 1: SS alone
// ============================================================
/**
 * Phase 1: SS tcgen05.mma alone (baseline).
 *
 * Launch: one block of 128 threads (see __launch_bounds__ and main()).
 * Dynamic SMEM: bytes [0,4) receive the TMEM base address from tmem_alloc;
 * then, rounded up to 16 bytes, two 128x16 bf16 tiles in the canonical
 * K-major layout: sA (all 1.0) and sB (all 2.0). Static SMEM: one mbarrier.
 *
 * TMEM: 128 columns. The SS MMA writes C(128,128) to columns 0-127; warp 0
 * prints row 0, columns 0-15, after the MMA has been committed and awaited.
 */
__global__ void __launch_bounds__(128)
test_phase1_ss()
{
    const int tid = threadIdx.x, wid = tid / 32;
    extern __shared__ char sbuf[];
    uint32_t* sTmemBase = (uint32_t*)sbuf;
    bf16_t* sA = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15);
    bf16_t* sB = sA + 128 * 16;  // second (128,16) tile, used as B

    // ---- tcgen05 synchronization (kept local so this kernel is self-contained) ----
    // mbarrier that tid 0's tcgen05.commit arrives on once every MMA it has
    // issued so far has completed. Static SMEM, so the dynamic layout above is
    // unaffected.
    __shared__ __align__(8) uint64_t mma_bar;
    const uint32_t mma_bar_addr = (uint32_t)__cvta_generic_to_shared(&mma_bar);
    uint32_t mma_phase = 0;  // parity of the next mbarrier phase to wait for
    if (tid == 0) {
        asm volatile("mbarrier.init.shared::cta.b64 [%0], 1;" :: "r"(mma_bar_addr) : "memory");
        asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory");
    }
    // CTA-wide barrier that also orders tcgen05 operations (alloc, st, ld,
    // mma, dealloc) issued before it against those issued after it.
    auto tc_sync = []() {
        asm volatile("tcgen05.fence::before_thread_sync;" ::: "memory");
        __syncthreads();
        asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    };
    // tcgen05.mma is asynchronous, so a CTA barrier alone does not mean its
    // result is in TMEM. tid 0 commits its outstanding MMAs to the mbarrier,
    // every thread waits for that phase, and the trailing fence orders later
    // tcgen05 ops (e.g. tcgen05.ld of the accumulator) after the wait.
    auto mma_commit_and_wait = [&]() {
        if (tid == 0)
            asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];"
                         :: "r"(mma_bar_addr) : "memory");
        uint32_t done = 0;
        while (!done)
            asm volatile("{\n\t.reg .pred p;\n\t"
                         "mbarrier.try_wait.parity.shared::cta.b64 p, [%1], %2;\n\t"
                         "selp.u32 %0, 1, 0, p;\n\t}"
                         : "=r"(done) : "r"(mma_bar_addr), "r"(mma_phase) : "memory");
        mma_phase ^= 1;
        __syncwarp();  // reconverge before the next .sync.aligned tcgen05 op
        asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    };

    fill_smem_128x16_ones(sA, 1.0f);
    fill_smem_128x16_ones(sB, 2.0f);
    // Publish the generic-proxy SMEM stores to the async proxy that
    // tcgen05.mma reads its operands through (made CTA-visible by tc_sync).
    asm volatile("fence.proxy.async.shared::cta;" ::: "memory");

    if (wid == 0) tmem_alloc(__cvta_generic_to_shared(sTmemBase), 128);
    tc_sync();
    const uint32_t tb = *sTmemBase;

    // SS: A(128,16) x B(128,16) -> C(128,128) at tb
    const uint64_t da = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sA), 128);
    const uint64_t db = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sB), 128);
    const uint32_t idesc = make_idesc(128, 128);
    if (tid == 0) umma_ss_f16(tb, da, db, idesc, false);
    mma_commit_and_wait();

    // Read first 16 cols of S
    if (wid == 0) tmem_read_print_16(tb, 0, "Phase1 SS S[0,0..15]");
    tc_sync();
    if (wid == 0) tmem_dealloc(tb, 128);
}
// ============================================================
// PHASE 2: TS alone
// ============================================================
/**
 * Phase 2: TS tcgen05.mma alone (baseline).
 *
 * Launch: one block of 128 threads. Dynamic SMEM: bytes [0,4) receive the
 * TMEM base address; then, rounded up to 16 bytes, a 16x16 bf16 tile sV
 * (all 2.0) in the canonical K-major layout. Static SMEM: one mbarrier.
 *
 * TMEM: 64 columns. Warp 0 writes A = 1.0 into columns 0-15 (its own 32 TMEM
 * lanes); the TS MMA computes A(128,16, TMEM) x sV -> C at columns 32-47;
 * warp 0 prints row 0 of C after the MMA has been committed and awaited.
 */
__global__ void __launch_bounds__(128)
test_phase2_ts()
{
    const int tid = threadIdx.x, wid = tid / 32;
    extern __shared__ char sbuf[];
    uint32_t* sTmemBase = (uint32_t*)sbuf;
    bf16_t* sV = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15);

    // ---- tcgen05 synchronization (kept local so this kernel is self-contained) ----
    // mbarrier that tid 0's tcgen05.commit arrives on once every MMA it has
    // issued so far has completed. Static SMEM, so the dynamic layout above is
    // unaffected.
    __shared__ __align__(8) uint64_t mma_bar;
    const uint32_t mma_bar_addr = (uint32_t)__cvta_generic_to_shared(&mma_bar);
    uint32_t mma_phase = 0;  // parity of the next mbarrier phase to wait for
    if (tid == 0) {
        asm volatile("mbarrier.init.shared::cta.b64 [%0], 1;" :: "r"(mma_bar_addr) : "memory");
        asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory");
    }
    // CTA-wide barrier that also orders tcgen05 operations (alloc, st, ld,
    // mma, dealloc) issued before it against those issued after it.
    auto tc_sync = []() {
        asm volatile("tcgen05.fence::before_thread_sync;" ::: "memory");
        __syncthreads();
        asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    };
    // tcgen05.mma is asynchronous, so a CTA barrier alone does not mean its
    // result is in TMEM. tid 0 commits its outstanding MMAs to the mbarrier,
    // every thread waits for that phase, and the trailing fence orders later
    // tcgen05 ops (e.g. tcgen05.ld of the accumulator) after the wait.
    auto mma_commit_and_wait = [&]() {
        if (tid == 0)
            asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];"
                         :: "r"(mma_bar_addr) : "memory");
        uint32_t done = 0;
        while (!done)
            asm volatile("{\n\t.reg .pred p;\n\t"
                         "mbarrier.try_wait.parity.shared::cta.b64 p, [%1], %2;\n\t"
                         "selp.u32 %0, 1, 0, p;\n\t}"
                         : "=r"(done) : "r"(mma_bar_addr), "r"(mma_phase) : "memory");
        mma_phase ^= 1;
        __syncwarp();  // reconverge before the next .sync.aligned tcgen05 op
        asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    };

    fill_smem_16x16_ones(sV, 2.0f);
    // Publish the generic-proxy SMEM stores to the async proxy that
    // tcgen05.mma reads its operands through (made CTA-visible by tc_sync).
    asm volatile("fence.proxy.async.shared::cta;" ::: "memory");

    if (wid == 0) tmem_alloc(__cvta_generic_to_shared(sTmemBase), 64);
    tc_sync();
    const uint32_t tb = *sTmemBase;

    // Write A = all 1.0 to cols 0-15, then hand the TMEM stores to the MMA issuer.
    if (wid == 0) { tmem_write_ones_128x16(tb, 0, 1.0f); tmem_fence_store(); }
    tc_sync();

    // TS: A(128,16, TMEM) x B(16,16, SMEM) -> C(128,16, TMEM) at tb+32
    const uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sV), 16);
    const uint32_t idesc = make_idesc(128, 16);
    if (tid == 0) umma_ts_f16(tb + 32, tb, dv, idesc, false);
    mma_commit_and_wait();

    // Read C
    if (wid == 0) tmem_read_print_16(tb, 32, "Phase2 TS C[0,0..15]");
    tc_sync();
    if (wid == 0) tmem_dealloc(tb, 64);
}
// ============================================================
// PHASE 3: SS then TS, same TMEM allocation
// This is the crash case we need to debug.
// ============================================================
/**
 * Phase 3: SS MMA then TS MMA in the same TMEM allocation (the crash case).
 *
 * Launch: one block of 128 threads. Dynamic SMEM: bytes [0,4) receive the
 * TMEM base address; then, rounded up to 16 bytes, two 128x16 bf16 tiles
 * (A = 1.0, B = 2.0); then, rounded up to 128 bytes, a 16x16 bf16 tile sV
 * (all 2.0). Static SMEM: one mbarrier.
 *
 * TMEM: 256 columns. Columns 0-127 hold S from the SS MMA. The TS MMA uses
 * S's leading columns as its A operand and writes O to columns 128-143.
 * Each MMA is committed and awaited before its result is read.
 */
__global__ void __launch_bounds__(128)
test_phase3_ss_then_ts()
{
    const int tid = threadIdx.x, wid = tid / 32;
    extern __shared__ char sbuf[];
    uint32_t* sTmemBase = (uint32_t*)sbuf;
    bf16_t* sAB = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15);
    bf16_t* sV = (bf16_t*)(((uintptr_t)(sAB + 2 * 128 * 16) + 127) & ~(uintptr_t)127);

    // ---- tcgen05 synchronization (kept local so this kernel is self-contained) ----
    // mbarrier that tid 0's tcgen05.commit arrives on once every MMA it has
    // issued so far has completed. Static SMEM, so the dynamic layout above is
    // unaffected.
    __shared__ __align__(8) uint64_t mma_bar;
    const uint32_t mma_bar_addr = (uint32_t)__cvta_generic_to_shared(&mma_bar);
    uint32_t mma_phase = 0;  // parity of the next mbarrier phase to wait for
    if (tid == 0) {
        asm volatile("mbarrier.init.shared::cta.b64 [%0], 1;" :: "r"(mma_bar_addr) : "memory");
        asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory");
    }
    // CTA-wide barrier that also orders tcgen05 operations (alloc, st, ld,
    // mma, dealloc) issued before it against those issued after it.
    auto tc_sync = []() {
        asm volatile("tcgen05.fence::before_thread_sync;" ::: "memory");
        __syncthreads();
        asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    };
    // tcgen05.mma is asynchronous, so a CTA barrier alone does not mean its
    // result is in TMEM. tid 0 commits its outstanding MMAs to the mbarrier,
    // every thread waits for that phase, and the trailing fence orders later
    // tcgen05 ops (e.g. tcgen05.ld of the accumulator) after the wait.
    auto mma_commit_and_wait = [&]() {
        if (tid == 0)
            asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];"
                         :: "r"(mma_bar_addr) : "memory");
        uint32_t done = 0;
        while (!done)
            asm volatile("{\n\t.reg .pred p;\n\t"
                         "mbarrier.try_wait.parity.shared::cta.b64 p, [%1], %2;\n\t"
                         "selp.u32 %0, 1, 0, p;\n\t}"
                         : "=r"(done) : "r"(mma_bar_addr), "r"(mma_phase) : "memory");
        mma_phase ^= 1;
        __syncwarp();  // reconverge before the next .sync.aligned tcgen05 op
        asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    };

    // SS buffers
    fill_smem_128x16_ones(sAB, 1.0f);
    fill_smem_128x16_ones(sAB + 128 * 16, 2.0f);
    // TS buffer
    fill_smem_16x16_ones(sV, 2.0f);
    // Publish the generic-proxy SMEM stores to the async proxy that
    // tcgen05.mma reads its operands through (made CTA-visible by tc_sync).
    asm volatile("fence.proxy.async.shared::cta;" ::: "memory");

    // TMEM: 256 columns. 0-127 = SS output (S). 128-143 = TS output (O).
    if (wid == 0) tmem_alloc(__cvta_generic_to_shared(sTmemBase), 256);
    tc_sync();
    const uint32_t tb = *sTmemBase;

    // STEP 1: SS QK GEMM -> S at tb (cols 0-127)
    const uint64_t da = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sAB), 128);
    const uint64_t db = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sAB + 128 * 16), 128);
    const uint32_t idesc_ss = make_idesc(128, 128);
    if (tid == 0) umma_ss_f16(tb, da, db, idesc_ss, false);
    mma_commit_and_wait();

    // Verify SS output
    if (wid == 0) tmem_read_print_16(tb, 0, "Phase3 SS S[0,0..15]");

    // STEP 2: TS PV GEMM -> O at tb+128 (cols 128-143)
    //   A = leading columns of S (tb + 0), B = sV (16, 16), C = O at tb + 128
    // NOTE(review): S is the SS accumulator (presumably fp32, one value per
    // 32-bit column), but it is fed to the TS MMA unconverted as its A operand,
    // which a kind::f16 MMA reads as 16-bit elements. The bits are
    // reinterpreted, not converted to bf16 — presumably intentional, to mirror
    // the crashing kernel; confirm before reading meaning into O's values.
    const uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sV), 16);
    const uint32_t idesc_ts = make_idesc(128, 16);
    if (tid == 0) umma_ts_f16(tb + 128, tb, dv, idesc_ts, false);
    mma_commit_and_wait();

    // Read TS output
    if (wid == 0) tmem_read_print_16(tb, 128, "Phase3 TS O[0,0..15]");
    tc_sync();
    if (wid == 0) tmem_dealloc(tb, 256);
}
// ============================================================
// PHASE 4: SS then TS, with TMEM dealloc + realloc between them
// Tests whether the crash is due to TMEM state from SS interfering with TS
// ============================================================
/**
 * Phase 4: SS MMA, then TMEM dealloc + realloc, then an independent TS MMA.
 * Tests whether TMEM state left behind by the SS MMA interferes with TS.
 *
 * Launch: one block of 128 threads. Dynamic SMEM has the same layout as
 * Phase 3: TMEM-base slot at [0,4), two 16-byte-aligned 128x16 bf16 tiles
 * (A = 1.0, B = 2.0), and a 128-byte-aligned 16x16 bf16 tile sV (all 2.0).
 * Static SMEM: one mbarrier.
 *
 * Step 1: 128-column allocation; SS MMA -> S at columns 0-127; warp 0 prints
 *         row 0, columns 0-15; dealloc.
 * Step 2: fresh 64-column allocation; warp 0 writes A = 1.0 to columns 0-15;
 *         TS MMA A x sV -> C at columns 32-47; warp 0 prints row 0; dealloc.
 *
 * S is only observed through tmem_read_print_16. TMEM loads must be issued by
 * all 32 lanes of a warp (.sync.aligned), never by a single lane.
 */
__global__ void __launch_bounds__(128)
test_phase4_ss_ts_separate_tmem()
{
    const int tid = threadIdx.x, wid = tid / 32;
    extern __shared__ char sbuf[];
    uint32_t* sTmemBase = (uint32_t*)sbuf;
    bf16_t* sAB = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15);
    bf16_t* sV = (bf16_t*)(((uintptr_t)(sAB + 2 * 128 * 16) + 127) & ~(uintptr_t)127);

    // ---- tcgen05 synchronization (kept local so this kernel is self-contained) ----
    // mbarrier that tid 0's tcgen05.commit arrives on once every MMA it has
    // issued so far has completed. Static SMEM, so the dynamic layout above is
    // unaffected.
    __shared__ __align__(8) uint64_t mma_bar;
    const uint32_t mma_bar_addr = (uint32_t)__cvta_generic_to_shared(&mma_bar);
    uint32_t mma_phase = 0;  // parity of the next mbarrier phase to wait for
    if (tid == 0) {
        asm volatile("mbarrier.init.shared::cta.b64 [%0], 1;" :: "r"(mma_bar_addr) : "memory");
        asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory");
    }
    // CTA-wide barrier that also orders tcgen05 operations (alloc, st, ld,
    // mma, dealloc) issued before it against those issued after it.
    auto tc_sync = []() {
        asm volatile("tcgen05.fence::before_thread_sync;" ::: "memory");
        __syncthreads();
        asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    };
    // tcgen05.mma is asynchronous, so a CTA barrier alone does not mean its
    // result is in TMEM. tid 0 commits its outstanding MMAs to the mbarrier,
    // every thread waits for that phase, and the trailing fence orders later
    // tcgen05 ops (e.g. tcgen05.ld of the accumulator) after the wait.
    auto mma_commit_and_wait = [&]() {
        if (tid == 0)
            asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];"
                         :: "r"(mma_bar_addr) : "memory");
        uint32_t done = 0;
        while (!done)
            asm volatile("{\n\t.reg .pred p;\n\t"
                         "mbarrier.try_wait.parity.shared::cta.b64 p, [%1], %2;\n\t"
                         "selp.u32 %0, 1, 0, p;\n\t}"
                         : "=r"(done) : "r"(mma_bar_addr), "r"(mma_phase) : "memory");
        mma_phase ^= 1;
        __syncwarp();  // reconverge before the next .sync.aligned tcgen05 op
        asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    };

    fill_smem_128x16_ones(sAB, 1.0f);
    fill_smem_128x16_ones(sAB + 128 * 16, 2.0f);
    fill_smem_16x16_ones(sV, 2.0f);
    // Publish the generic-proxy SMEM stores to the async proxy that
    // tcgen05.mma reads its operands through (made CTA-visible by tc_sync).
    asm volatile("fence.proxy.async.shared::cta;" ::: "memory");

    // STEP 1: SS with 128-col TMEM
    if (wid == 0) tmem_alloc(__cvta_generic_to_shared(sTmemBase), 128);
    tc_sync();
    const uint32_t tb1 = *sTmemBase;
    const uint64_t da = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sAB), 128);
    const uint64_t db = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sAB + 128 * 16), 128);
    const uint32_t idesc_ss = make_idesc(128, 128);
    if (tid == 0) umma_ss_f16(tb1, da, db, idesc_ss, false);
    mma_commit_and_wait();
    if (wid == 0) tmem_read_print_16(tb1, 0, "Phase4 SS S[0,0..15]");
    tc_sync();

    // Dealloc SS TMEM
    if (wid == 0) tmem_dealloc(tb1, 128);
    tc_sync();

    // STEP 2: Realloc TMEM for TS - 64 cols
    if (wid == 0) tmem_alloc(__cvta_generic_to_shared(sTmemBase), 64);
    tc_sync();
    const uint32_t tb2 = *sTmemBase;

    // Write A = all 1.0 to new TMEM cols 0-15, then hand off to the MMA issuer.
    if (wid == 0) { tmem_write_ones_128x16(tb2, 0, 1.0f); tmem_fence_store(); }
    tc_sync();

    // TS MMA
    const uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sV), 16);
    const uint32_t idesc_ts = make_idesc(128, 16);
    if (tid == 0) umma_ts_f16(tb2 + 32, tb2, dv, idesc_ts, false);
    mma_commit_and_wait();
    if (wid == 0) tmem_read_print_16(tb2, 32, "Phase4 TS C[0,0..15]");
    tc_sync();
    if (wid == 0) tmem_dealloc(tb2, 64);
}
// ============================================================
// PHASE 5: Two SS calls in sequence
// Does SS→SS work?
// ============================================================
/**
 * Phase 5: two SS MMAs in sequence (does SS -> SS work?).
 *
 * Launch: one block of 128 threads. Dynamic SMEM: TMEM-base slot at [0,4),
 * then two 16-byte-aligned 128x16 bf16 tiles sA1 (all 1.0) and sB1 (all 2.0).
 * Static SMEM: one mbarrier.
 *
 * TMEM: 256 columns. The same A x B product is written twice: to columns
 * 0-127 and then to columns 128-255. Warp 0 prints row 0, columns 0-15, of
 * each after the corresponding MMA has been committed and awaited.
 */
__global__ void __launch_bounds__(128)
test_phase5_ss_ss()
{
    const int tid = threadIdx.x, wid = tid / 32;
    extern __shared__ char sbuf[];
    uint32_t* sTmemBase = (uint32_t*)sbuf;
    bf16_t* sA1 = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15);
    bf16_t* sB1 = sA1 + 128 * 16;

    // ---- tcgen05 synchronization (kept local so this kernel is self-contained) ----
    // mbarrier that tid 0's tcgen05.commit arrives on once every MMA it has
    // issued so far has completed. Static SMEM, so the dynamic layout above is
    // unaffected.
    __shared__ __align__(8) uint64_t mma_bar;
    const uint32_t mma_bar_addr = (uint32_t)__cvta_generic_to_shared(&mma_bar);
    uint32_t mma_phase = 0;  // parity of the next mbarrier phase to wait for
    if (tid == 0) {
        asm volatile("mbarrier.init.shared::cta.b64 [%0], 1;" :: "r"(mma_bar_addr) : "memory");
        asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory");
    }
    // CTA-wide barrier that also orders tcgen05 operations (alloc, st, ld,
    // mma, dealloc) issued before it against those issued after it.
    auto tc_sync = []() {
        asm volatile("tcgen05.fence::before_thread_sync;" ::: "memory");
        __syncthreads();
        asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    };
    // tcgen05.mma is asynchronous, so a CTA barrier alone does not mean its
    // result is in TMEM. tid 0 commits its outstanding MMAs to the mbarrier,
    // every thread waits for that phase, and the trailing fence orders later
    // tcgen05 ops (e.g. tcgen05.ld of the accumulator) after the wait.
    auto mma_commit_and_wait = [&]() {
        if (tid == 0)
            asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];"
                         :: "r"(mma_bar_addr) : "memory");
        uint32_t done = 0;
        while (!done)
            asm volatile("{\n\t.reg .pred p;\n\t"
                         "mbarrier.try_wait.parity.shared::cta.b64 p, [%1], %2;\n\t"
                         "selp.u32 %0, 1, 0, p;\n\t}"
                         : "=r"(done) : "r"(mma_bar_addr), "r"(mma_phase) : "memory");
        mma_phase ^= 1;
        __syncwarp();  // reconverge before the next .sync.aligned tcgen05 op
        asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    };

    fill_smem_128x16_ones(sA1, 1.0f);
    fill_smem_128x16_ones(sB1, 2.0f);
    // Publish the generic-proxy SMEM stores to the async proxy that
    // tcgen05.mma reads its operands through (made CTA-visible by tc_sync).
    asm volatile("fence.proxy.async.shared::cta;" ::: "memory");

    if (wid == 0) tmem_alloc(__cvta_generic_to_shared(sTmemBase), 256);
    tc_sync();
    const uint32_t tb = *sTmemBase;

    // SS #1 -> cols 0-127
    const uint64_t da1 = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sA1), 128);
    const uint64_t db1 = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sB1), 128);
    const uint32_t idesc1 = make_idesc(128, 128);
    if (tid == 0) umma_ss_f16(tb, da1, db1, idesc1, false);
    mma_commit_and_wait();
    if (wid == 0) tmem_read_print_16(tb, 0, "Phase5 SS#1 S[0,0..15]");

    // SS #2 -> cols 128-255 (separate output region)
    if (tid == 0) umma_ss_f16(tb + 128, da1, db1, idesc1, false);
    mma_commit_and_wait();
    if (wid == 0) tmem_read_print_16(tb, 128, "Phase5 SS#2 S[0,0..15]");
    tc_sync();
    if (wid == 0) tmem_dealloc(tb, 256);
}
// ============================================================
// PHASE 6: Two TS calls in sequence
// Does TS→TS work?
// ============================================================
/**
 * Phase 6: two TS MMAs in sequence (does TS -> TS work?).
 *
 * Launch: one block of 128 threads. Dynamic SMEM: TMEM-base slot at [0,4),
 * then a 16-byte-aligned 16x16 bf16 tile sV (all 2.0). Static SMEM: one
 * mbarrier.
 *
 * TMEM: 128 columns. Warp 0 writes A1 = 1.0 at columns 0-15 and A2 = 3.0 at
 * columns 64-79. TS #1 computes A1 x sV -> columns 32-47, and TS #2 computes
 * A2 x sV -> columns 96-111. Warp 0 prints row 0 of each after the
 * corresponding MMA has been committed and awaited.
 */
__global__ void __launch_bounds__(128)
test_phase6_ts_ts()
{
    const int tid = threadIdx.x, wid = tid / 32;
    extern __shared__ char sbuf[];
    uint32_t* sTmemBase = (uint32_t*)sbuf;
    bf16_t* sV = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15);

    // ---- tcgen05 synchronization (kept local so this kernel is self-contained) ----
    // mbarrier that tid 0's tcgen05.commit arrives on once every MMA it has
    // issued so far has completed. Static SMEM, so the dynamic layout above is
    // unaffected.
    __shared__ __align__(8) uint64_t mma_bar;
    const uint32_t mma_bar_addr = (uint32_t)__cvta_generic_to_shared(&mma_bar);
    uint32_t mma_phase = 0;  // parity of the next mbarrier phase to wait for
    if (tid == 0) {
        asm volatile("mbarrier.init.shared::cta.b64 [%0], 1;" :: "r"(mma_bar_addr) : "memory");
        asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory");
    }
    // CTA-wide barrier that also orders tcgen05 operations (alloc, st, ld,
    // mma, dealloc) issued before it against those issued after it.
    auto tc_sync = []() {
        asm volatile("tcgen05.fence::before_thread_sync;" ::: "memory");
        __syncthreads();
        asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    };
    // tcgen05.mma is asynchronous, so a CTA barrier alone does not mean its
    // result is in TMEM. tid 0 commits its outstanding MMAs to the mbarrier,
    // every thread waits for that phase, and the trailing fence orders later
    // tcgen05 ops (e.g. tcgen05.ld of the accumulator) after the wait.
    auto mma_commit_and_wait = [&]() {
        if (tid == 0)
            asm volatile("tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];"
                         :: "r"(mma_bar_addr) : "memory");
        uint32_t done = 0;
        while (!done)
            asm volatile("{\n\t.reg .pred p;\n\t"
                         "mbarrier.try_wait.parity.shared::cta.b64 p, [%1], %2;\n\t"
                         "selp.u32 %0, 1, 0, p;\n\t}"
                         : "=r"(done) : "r"(mma_bar_addr), "r"(mma_phase) : "memory");
        mma_phase ^= 1;
        __syncwarp();  // reconverge before the next .sync.aligned tcgen05 op
        asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    };

    fill_smem_16x16_ones(sV, 2.0f);
    // Publish the generic-proxy SMEM stores to the async proxy that
    // tcgen05.mma reads its operands through (made CTA-visible by tc_sync).
    asm volatile("fence.proxy.async.shared::cta;" ::: "memory");

    if (wid == 0) tmem_alloc(__cvta_generic_to_shared(sTmemBase), 128);
    tc_sync();
    const uint32_t tb = *sTmemBase;

    // Write A1 at cols 0-15, A2 at cols 64-79, then hand off to the MMA issuer.
    if (wid == 0) {
        tmem_write_ones_128x16(tb, 0, 1.0f);
        tmem_write_ones_128x16(tb, 64, 3.0f);
        tmem_fence_store();
    }
    tc_sync();

    // TS #1: A1 x V -> C1 at cols 32-47
    const uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sV), 16);
    const uint32_t idesc = make_idesc(128, 16);
    if (tid == 0) umma_ts_f16(tb + 32, tb, dv, idesc, false);
    mma_commit_and_wait();
    if (wid == 0) tmem_read_print_16(tb, 32, "Phase6 TS#1 C[0,0..15]");

    // TS #2: A2 x V -> C2 at cols 96-111
    if (tid == 0) umma_ts_f16(tb + 96, tb + 64, dv, idesc, false);
    mma_commit_and_wait();
    if (wid == 0) tmem_read_print_16(tb, 96, "Phase6 TS#2 C[0,0..15]");
    tc_sync();
    if (wid == 0) tmem_dealloc(tb, 128);
}
/**
 * Run each phase kernel once as a single block of 128 threads and report
 * whether it completed ("PASS" means only "did not fault"; correctness is
 * judged from the printed TMEM values).
 *
 * Returns 0 when every phase completed and 1 when any phase failed to launch
 * or faulted, so CI/scripts can detect the crash.
 */
int main() {
    printf("=== SS + TS Sequence Test ===\n\n");
    int crashes = 0;
    // Launch-configuration errors are reported by cudaGetLastError(); faults
    // inside the kernel surface at cudaDeviceSynchronize(). A fault leaves a
    // sticky error on the context, so the device is reset before the next phase.
    auto run = [&crashes](const char* name, void (*kernel)(), int smem) {
        printf("--- %s ---\n", name);
        kernel<<<1, 128, smem>>>();
        cudaError_t err = cudaGetLastError();
        if (err == cudaSuccess) err = cudaDeviceSynchronize();
        if (err != cudaSuccess) {
            printf(" CRASH: %s\n\n", cudaGetErrorString(err));
            ++crashes;
            // Reset GPU state for next test
            cudaDeviceReset();
        } else {
            printf(" PASS\n\n");
        }
    };
    // Dynamic SMEM sizes: 4-byte TMEM-base slot + alignment slack + operand
    // tiles + 256 bytes headroom, rounded up to a multiple of 128.
    // Phase 1: SS alone
    const int smem1 = (4 + 16 + 2 * 128 * 16 * 2 + 256 + 127) & ~127;
    run("Phase 1: SS alone", test_phase1_ss, smem1);
    // Phase 2: TS alone
    const int smem2 = (4 + 16 + 16 * 16 * 2 + 256 + 127) & ~127;
    run("Phase 2: TS alone", test_phase2_ts, smem2);
    // Phase 3: SS then TS (the crash case)
    const int smem3 = (4 + 16 + 2 * 128 * 16 * 2 + 16 * 16 * 2 + 256 + 127) & ~127;
    run("Phase 3: SS → TS (same TMEM)", test_phase3_ss_then_ts, smem3);
    // Phase 4: SS then TS with dealloc/realloc (same SMEM layout as Phase 3)
    run("Phase 4: SS → dealloc → TS (separate TMEM)", test_phase4_ss_ts_separate_tmem, smem3);
    // Phase 5: SS → SS
    const int smem5 = (4 + 16 + 2 * 128 * 16 * 2 + 256 + 127) & ~127;
    run("Phase 5: SS → SS", test_phase5_ss_ss, smem5);
    // Phase 6: TS → TS
    const int smem6 = (4 + 16 + 16 * 16 * 2 + 256 + 127) & ~127;
    run("Phase 6: TS → TS", test_phase6_ts_ts, smem6);
    printf("=== %d phase(s) crashed ===\n", crashes);
    return crashes == 0 ? 0 : 1;
}