// Files
// nvfp4-megamoe-kernel/tests/unit/test_ss_ts_sequence.cu

// 440 lines
// 16 KiB
// Plaintext
// Raw Permalink Blame History

// This file contains ambiguous Unicode characters
// This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
/**
* Systematic test: Can tcgen05.mma SS and TS coexist in the same kernel?
*
* The crash from combining QK (SS) + PV (TS) in the same kernel suggests
* a TMEM state issue. This test systematically varies:
* 1. TMEM allocation size
* 2. Whether SS and TS share the same TMEM region or use separate regions
* 3. Whether there's a TMEM dealloc+realloc between SS and TS
* 4. Whether the issue is SS→TS specifically, or any second MMA call
*
* The working test_mma_ts.cu uses TS alone and works. test_fmha_hd64.cu
* uses SS alone (with register-math PV) and works. The crash only happens
* when both are in the same kernel.
*
* Key hypothesis: After tcgen05.mma SS, the TMEM C operand's internal
* state may conflict with the TS MMA's A operand read. The SS MMA writes
* to TMEM columns in Layout D format. The TS MMA reads A from TMEM and
* expects it in a specific format. If the formats don't match, or if
* there's a hardware fence missing between SS and TS, that could cause
* the crash.
*
* Alternative hypothesis: The issue is simpler — TMEM column accounting.
* SS writes 128 columns. TS reads from columns 0-15 (subset of P) and
* writes to columns 128-143 (O region). Maybe the SS MMA reserves or
* locks TMEM columns in a way that TS can't access.
*
* Test plan:
* Phase 1: SS alone (baseline — should work)
* Phase 2: TS alone (baseline — should work)
* Phase 3: SS then TS, same TMEM (the crash case)
* Phase 4: SS then TS, separate TMEM allocs (dealloc after SS, realloc for TS)
* Phase 5: Two SS calls in sequence (does SS→SS work?)
* Phase 6: Two TS calls in sequence (does TS→TS work?)
* Phase 7: SS then TS with explicit TMEM barrier (planned only: no Phase 7
*          kernel is defined in this file and main() launches Phases 1-6)
*/
#include <cuda_runtime.h>
#include <cstdio>
#include <cmath>
#include <cstring>
#include "dsv4/kernels/attention/fmha_common.cuh"
#include "dsv4/kernels/attention/fmha_umma_desc.cuh"
using namespace dsv4::kernels::attention;
// ============================================================
// Helper: fill SMEM (16,16) canonical with all-1s
// ============================================================
// Fills a 16x16 bf16 tile in shared memory with `val` (default 1.0).
//
// The tile is stored as four 8x8 blocks of 64 contiguous elements each.
// Logical element (r, c) is written to
//     (c / 8) * 128 + (r / 8) * 64 + (r % 8) * 8 + (c % 8)
// which covers indices [0, 256) exactly once.
//
// s   : destination buffer holding at least 256 bf16_t elements.
// val : value converted with f32_to_bf16 and written to every element.
//
// Must be called by all threads of the block: it contains __syncthreads().
// The loops stride by blockDim.x (the original hard-coded 128), so every
// element is written for any 1-D block size, not only 128 threads.
__device__ void fill_smem_16x16_ones(bf16_t* s, float val = 1.0f) {
    constexpr int kElems = 16 * 16;
    const int stride = static_cast<int>(blockDim.x);

    // Clear pass, kept from the original before the layout-aware fill.
    for (int i = threadIdx.x; i < kElems; i += stride) s[i] = 0;
    __syncthreads();

    for (int i = threadIdx.x; i < kElems; i += stride) {
        const int r = i / 16, c = i % 16;      // logical row / column
        const int ck = c / 8, lc = c % 8;      // 8-wide column block, column inside it
        const int tmn = r / 8, lr = r % 8;     // 8-tall row block, row inside it
        s[ck * 2 * 64 + tmn * 64 + lr * 8 + lc] = f32_to_bf16(val);
    }
    __syncthreads();
}
// ============================================================
// Helper: fill SMEM (128,16) canonical with all-1s (for Q/K)
// ============================================================
// Fills a 128x16 bf16 tile in shared memory with `val` (default 1.0).
//
// The tile is stored as 8x8 blocks of 64 contiguous elements: 16 blocks along
// the row dimension and 2 along the column dimension. Logical element (r, c)
// is written to
//     (c / 8) * 16 * 64 + (r / 8) * 64 + (r % 8) * 8 + (c % 8)
// which covers indices [0, 2048) exactly once.
//
// s   : destination buffer holding at least 128 * 16 bf16_t elements.
// val : value converted with f32_to_bf16 and written to every element.
//
// Must be called by all threads of the block: it contains __syncthreads().
// The loops stride by blockDim.x (the original hard-coded 128), so every
// element is written for any 1-D block size, not only 128 threads.
__device__ void fill_smem_128x16_ones(bf16_t* s, float val = 1.0f) {
    constexpr int CORES_MN = 16;               // 8-row blocks along the 128-row dimension
    constexpr int kElems = 128 * 16;
    const int stride = static_cast<int>(blockDim.x);

    // Clear pass, kept from the original before the layout-aware fill.
    for (int i = threadIdx.x; i < kElems; i += stride) s[i] = 0;
    __syncthreads();

    for (int i = threadIdx.x; i < kElems; i += stride) {
        const int r = i / 16, c = i % 16;      // logical row / column
        const int ck = c / 8, lc = c % 8;      // 8-wide column block, column inside it
        const int tmn = r / 8, lr = r % 8;     // 8-tall row block, row inside it
        s[ck * CORES_MN * 64 + tmn * 64 + lr * 8 + lc] = f32_to_bf16(val);
    }
    __syncthreads();
}
// ============================================================
// Helper: write A=all-1s to TMEM cols [col_start, col_start+16)
// using 32x32b.x8 stores (warp-collective, all lanes write 1.0)
// ============================================================
// Stores `val` (default 1.0) into 16 consecutive TMEM columns starting at
// column offset `col_start` from the base address `tb`, using two
// tcgen05.st 32x32b.x8 stores of 8 columns each.
//
// tb        : TMEM base address.
// col_start : column offset added to `tb`.
// val       : 32-bit float value every lane writes to each of its 8 registers.
//
// The store is `.sync.aligned`, so all lanes of the calling warp must reach
// this call together.
// NOTE(review): with the 32x32b shape a single warp presumably covers only its
// own 32 TMEM lanes, so filling all 128 rows would need all four warps --
// confirm against the PTX tcgen05.st documentation and the callers.
__device__ void tmem_write_ones_128x16(uint32_t tb, int col_start, float val = 1.0f) {
    constexpr int kColsPerStore = 8;
    constexpr int kNumStores = 16 / kColsPerStore;

    for (int chunk = 0; chunk < kNumStores; ++chunk) {
        const uint32_t addr = tb + col_start + chunk * kColsPerStore;
        asm volatile("tcgen05.st.sync.aligned.32x32b.x8.b32 [%0],{%1,%2,%3,%4,%5,%6,%7,%8};"
                     :: "r"(addr),
                        "f"(val), "f"(val), "f"(val), "f"(val),
                        "f"(val), "f"(val), "f"(val), "f"(val));
    }
}
// ============================================================
// Helper: read 16 TMEM cols starting at col_start, print row 0
// ============================================================
// Loads 16 consecutive TMEM columns starting at column offset `col_start`
// from base address `tb` and prints the values held by lane 0 of the calling
// warp on one line, prefixed by `label`.
//
// tb        : TMEM base address.
// col_start : column offset added to `tb`.
// label     : text printed before the values.
//
// The loads are issued as two tcgen05.ld 32x32b.x8 instructions, each followed
// by tcgen05.wait::ld. Both are `.sync.aligned`, so every lane of the calling
// warp must execute this function; only lane 0 keeps and prints its values.
__device__ void tmem_read_print_16(uint32_t tb, int col_start, const char* label) {
    const bool is_lane0 = (threadIdx.x % 32) == 0;
    float row_vals[16];

    for (int chunk = 0; chunk < 2; ++chunk) {
        float regs[8];
        asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
                     : "=f"(regs[0]), "=f"(regs[1]), "=f"(regs[2]), "=f"(regs[3]),
                       "=f"(regs[4]), "=f"(regs[5]), "=f"(regs[6]), "=f"(regs[7])
                     : "r"(tb + col_start + chunk * 8));
        asm volatile("tcgen05.wait::ld.sync.aligned;");
        if (is_lane0) {
            float* dst = row_vals + chunk * 8;
            for (int k = 0; k < 8; ++k) dst[k] = regs[k];
        }
    }

    // All warp-collective instructions are done; the other lanes have nothing to print.
    if (!is_lane0) return;

    printf("%s: ", label);
    for (int k = 0; k < 16; ++k) printf("%.1f ", row_vals[k]);
    printf("\n");
}
// ============================================================
// PHASE 1: SS alone
// ============================================================
// Phase 1 baseline: a single SS MMA (both operands described by SMEM
// descriptors) writing to a 128-column TMEM allocation.
//
// Launch: one block of 128 threads. Dynamic shared memory must hold a 4-byte
// TMEM base slot, padding to 16-byte alignment, and two 128x16 bf16 tiles.
//
// Flow: fill A with 1.0 and B with 2.0, allocate 128 TMEM columns (warp 0),
// issue the MMA from thread 0, print the first 16 output columns as seen by
// lane 0 of warp 0, then free the TMEM allocation.
// NOTE(review): the only synchronization visible between the MMA issue and the
// TMEM read is tcgen05.fence::after_thread_sync + __syncthreads(); whether
// umma_ss_f16 itself waits for MMA completion is not visible here -- confirm.
__global__ void __launch_bounds__(128)
test_phase1_ss()
{
    const int tid = threadIdx.x;
    const bool warp0 = (tid / 32) == 0;

    extern __shared__ char sbuf[];
    // First 4 bytes receive the TMEM base address; tiles start at the next
    // 16-byte boundary after them.
    uint32_t* tmem_base_slot = reinterpret_cast<uint32_t*>(sbuf);
    const uintptr_t tiles_addr =
        (reinterpret_cast<uintptr_t>(sbuf + 4) + 15) & ~static_cast<uintptr_t>(15);
    bf16_t* smem_a = reinterpret_cast<bf16_t*>(tiles_addr);
    bf16_t* smem_b = smem_a + 128 * 16;  // second (128,16) tile, directly after A

    fill_smem_128x16_ones(smem_a, 1.0f);
    fill_smem_128x16_ones(smem_b, 2.0f);

    if (warp0) tmem_alloc(__cvta_generic_to_shared(tmem_base_slot), 128);
    __syncthreads();
    uint32_t tb = *tmem_base_slot;

    // SS: A(128,16) x B(128,16) -> C(128,128) at tb
    uint64_t desc_a = make_umma_desc_kmajor_none(__cvta_generic_to_shared(smem_a), 128);
    uint64_t desc_b = make_umma_desc_kmajor_none(__cvta_generic_to_shared(smem_b), 128);
    uint32_t idesc = make_idesc(128, 128);
    if (tid == 0) umma_ss_f16(tb, desc_a, desc_b, idesc, false);
    asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    __syncthreads();

    // Show the first 16 columns of the result.
    if (warp0) tmem_read_print_16(tb, 0, "Phase1 SS S[0,0..15]");
    __syncthreads();

    if (warp0) tmem_dealloc(tb, 128);
}
// ============================================================
// PHASE 2: TS alone
// ============================================================
// Phase 2 baseline: a single TS MMA (A operand read from TMEM, B operand
// described by an SMEM descriptor) inside a 64-column TMEM allocation.
//
// Launch: one block of 128 threads. Dynamic shared memory must hold a 4-byte
// TMEM base slot, padding to 16-byte alignment, and one 16x16 bf16 tile.
//
// Flow: fill V with 2.0, allocate 64 TMEM columns (warp 0), have warp 0 store
// 1.0 into columns 0-15 as the A operand, issue the MMA from thread 0 with the
// output at column 32, print 16 output columns, then free the allocation.
// NOTE(review): only warp 0 writes the A operand; whether that populates all
// 128 rows the MMA reads depends on tcgen05.st lane-access rules -- confirm.
__global__ void __launch_bounds__(128)
test_phase2_ts()
{
    const int tid = threadIdx.x;
    const bool warp0 = (tid / 32) == 0;

    extern __shared__ char sbuf[];
    uint32_t* tmem_base_slot = reinterpret_cast<uint32_t*>(sbuf);
    const uintptr_t tile_addr =
        (reinterpret_cast<uintptr_t>(sbuf + 4) + 15) & ~static_cast<uintptr_t>(15);
    bf16_t* smem_v = reinterpret_cast<bf16_t*>(tile_addr);

    fill_smem_16x16_ones(smem_v, 2.0f);

    if (warp0) tmem_alloc(__cvta_generic_to_shared(tmem_base_slot), 64);
    __syncthreads();
    uint32_t tb = *tmem_base_slot;

    // A operand: 1.0 in TMEM columns 0-15.
    if (warp0) {
        tmem_write_ones_128x16(tb, 0, 1.0f);
        tmem_fence_store();
    }
    __syncthreads();

    // TS: A(128,16, TMEM) x B(16,16, SMEM) -> C(128,16, TMEM) at tb+32
    uint64_t desc_v = make_umma_desc_kmajor_none(__cvta_generic_to_shared(smem_v), 16);
    uint32_t idesc = make_idesc(128, 16);
    if (tid == 0) umma_ts_f16(tb + 32, tb, desc_v, idesc, false);
    asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    __syncthreads();

    // Show the result columns.
    if (warp0) tmem_read_print_16(tb, 32, "Phase2 TS C[0,0..15]");
    __syncthreads();

    if (warp0) tmem_dealloc(tb, 64);
}
// ============================================================
// PHASE 3: SS then TS, same TMEM allocation
// This is the crash case we need to debug.
// ============================================================
// Phase 3 (the case under investigation): an SS MMA followed by a TS MMA
// inside one 256-column TMEM allocation.
//
// Launch: one block of 128 threads. Dynamic shared memory must hold a 4-byte
// TMEM base slot, two 128x16 bf16 tiles, one 16x16 bf16 tile and the padding
// needed for the 16-byte and 128-byte alignments applied below.
//
// TMEM column plan: 0-127 receive the SS output (S); the TS MMA reads its A
// operand from column 0 and writes its output (O) at column 128.
// NOTE(review): the TS A operand is whatever the SS MMA left in columns 0-15.
// Whether that accumulator format matches what umma_ts_f16 / make_idesc expect
// for A is not visible in this file -- confirm before trusting the TS result.
__global__ void __launch_bounds__(128)
test_phase3_ss_then_ts()
{
    const int tid = threadIdx.x;
    const bool warp0 = (tid / 32) == 0;

    extern __shared__ char sbuf[];
    uint32_t* tmem_base_slot = reinterpret_cast<uint32_t*>(sbuf);
    const uintptr_t qk_addr =
        (reinterpret_cast<uintptr_t>(sbuf + 4) + 15) & ~static_cast<uintptr_t>(15);
    bf16_t* smem_qk = reinterpret_cast<bf16_t*>(qk_addr);   // A tile, then B tile
    bf16_t* smem_b = smem_qk + 128 * 16;
    // V starts at the first 128-byte boundary past the two SS tiles.
    const uintptr_t v_addr =
        (reinterpret_cast<uintptr_t>(smem_qk + 2 * 128 * 16) + 127) & ~static_cast<uintptr_t>(127);
    bf16_t* smem_v = reinterpret_cast<bf16_t*>(v_addr);

    // SS operands
    fill_smem_128x16_ones(smem_qk, 1.0f);
    fill_smem_128x16_ones(smem_b, 2.0f);
    // TS B operand
    fill_smem_16x16_ones(smem_v, 2.0f);

    if (warp0) tmem_alloc(__cvta_generic_to_shared(tmem_base_slot), 256);
    __syncthreads();
    uint32_t tb = *tmem_base_slot;

    // STEP 1: SS GEMM -> S at tb (columns 0-127)
    uint64_t desc_a = make_umma_desc_kmajor_none(__cvta_generic_to_shared(smem_qk), 128);
    uint64_t desc_b = make_umma_desc_kmajor_none(__cvta_generic_to_shared(smem_b), 128);
    uint32_t idesc_ss = make_idesc(128, 128);
    if (tid == 0) umma_ss_f16(tb, desc_a, desc_b, idesc_ss, false);
    asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    __syncthreads();

    // Show the SS output before it is consumed.
    if (warp0) tmem_read_print_16(tb, 0, "Phase3 SS S[0,0..15]");

    // STEP 2: TS GEMM -> O at tb+128
    //   A = TMEM at tb (start of S), B = smem_v (16,16), C = TMEM at tb+128
    uint64_t desc_v = make_umma_desc_kmajor_none(__cvta_generic_to_shared(smem_v), 16);
    uint32_t idesc_ts = make_idesc(128, 16);
    if (tid == 0) umma_ts_f16(tb + 128, tb, desc_v, idesc_ts, false);
    asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    __syncthreads();

    // Show the TS output.
    if (warp0) tmem_read_print_16(tb, 128, "Phase3 TS O[0,0..15]");
    __syncthreads();

    if (warp0) tmem_dealloc(tb, 256);
}
// ============================================================
// PHASE 4: SS then TS, with TMEM dealloc + realloc between them
// Tests whether the crash is due to TMEM state from SS interfering with TS
// ============================================================
// Phase 4: SS MMA, then free the TMEM allocation, allocate a fresh one and run
// a TS MMA there. Checks whether TMEM state left behind by SS is what breaks a
// following TS in the same kernel.
//
// Launch: one block of 128 threads. Dynamic shared memory must hold a 4-byte
// TMEM base slot, two 128x16 bf16 tiles, one 16x16 bf16 tile and the padding
// needed for the 16-byte and 128-byte alignments applied below.
//
// Fix relative to the original: the snapshot of S used to execute
// tcgen05.ld / tcgen05.wait::ld under `wid == 0 && lane == 0`. Both are
// `.sync.aligned` (warp-collective) instructions, so running them on a single
// lane is undefined behaviour. All of warp 0 now executes them and only lane 0
// keeps the values.
__global__ void __launch_bounds__(128)
test_phase4_ss_ts_separate_tmem()
{
    const int tid = threadIdx.x, wid = tid / 32, lane = tid % 32;

    extern __shared__ char sbuf[];
    uint32_t* sTmemBase = (uint32_t*)sbuf;
    bf16_t* sAB = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15);
    // V starts at the first 128-byte boundary past the two SS tiles.
    bf16_t* sV = (bf16_t*)(((uintptr_t)(sAB + 2 * 128 * 16) + 127) & ~(uintptr_t)127);

    fill_smem_128x16_ones(sAB, 1.0f);
    fill_smem_128x16_ones(sAB + 128 * 16, 2.0f);
    fill_smem_16x16_ones(sV, 2.0f);

    // STEP 1: SS with a 128-column TMEM allocation
    if (wid == 0) tmem_alloc(__cvta_generic_to_shared(sTmemBase), 128);
    __syncthreads();
    uint32_t tb1 = *sTmemBase;

    uint64_t da = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sAB), 128);
    uint64_t db = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sAB + 128 * 16), 128);
    uint32_t idesc_ss = make_idesc(128, 128);
    if (tid == 0) umma_ss_f16(tb1, da, db, idesc_ss, false);
    asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    __syncthreads();
    if (wid == 0) tmem_read_print_16(tb1, 0, "Phase4 SS S[0,0..15]");

    // Snapshot S (lane 0's first 16 columns) before the allocation is freed.
    // The whole warp takes part in the collective load; only lane 0 stores.
    float saved_s[16] = {0};
    if (wid == 0) {
        for (int n = 0; n < 2; n++) {
            float tmp[8];
            asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
                         : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
                           "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
                         : "r"(tb1 + n*8));
            asm volatile("tcgen05.wait::ld.sync.aligned;");
            if (lane == 0) {
                for (int c = 0; c < 8; c++) saved_s[n*8+c] = tmp[c];
            }
        }
    }
    (void)saved_s;  // the snapshot is not consumed later, as in the original
    __syncthreads();

    // Free the SS allocation.
    if (wid == 0) tmem_dealloc(tb1, 128);
    __syncthreads();

    // STEP 2: fresh 64-column TMEM allocation for TS
    if (wid == 0) tmem_alloc(__cvta_generic_to_shared(sTmemBase), 64);
    __syncthreads();
    uint32_t tb2 = *sTmemBase;

    // A operand: 1.0 in columns 0-15 of the new allocation.
    if (wid == 0) { tmem_write_ones_128x16(tb2, 0, 1.0f); tmem_fence_store(); }
    __syncthreads();

    // TS MMA: A at tb2 (TMEM), B = sV (SMEM), output at tb2+32
    uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sV), 16);
    uint32_t idesc_ts = make_idesc(128, 16);
    if (tid == 0) umma_ts_f16(tb2 + 32, tb2, dv, idesc_ts, false);
    asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    __syncthreads();

    if (wid == 0) tmem_read_print_16(tb2, 32, "Phase4 TS C[0,0..15]");
    __syncthreads();

    if (wid == 0) tmem_dealloc(tb2, 64);
}
// ============================================================
// PHASE 5: Two SS calls in sequence
// Does SS→SS work?
// ============================================================
// Phase 5: two SS MMAs back to back in one 256-column TMEM allocation, to see
// whether a second MMA of the same kind is a problem on its own.
//
// Launch: one block of 128 threads. Dynamic shared memory must hold a 4-byte
// TMEM base slot, padding to 16-byte alignment, and two 128x16 bf16 tiles.
//
// Both MMAs use the same operand descriptors and instruction descriptor; the
// first writes at column 0, the second at column 128. After each one, warp 0
// prints 16 columns of that MMA's output region.
__global__ void __launch_bounds__(128)
test_phase5_ss_ss()
{
    const int tid = threadIdx.x;
    const bool warp0 = (tid / 32) == 0;

    extern __shared__ char sbuf[];
    uint32_t* tmem_base_slot = reinterpret_cast<uint32_t*>(sbuf);
    const uintptr_t tiles_addr =
        (reinterpret_cast<uintptr_t>(sbuf + 4) + 15) & ~static_cast<uintptr_t>(15);
    bf16_t* smem_a = reinterpret_cast<bf16_t*>(tiles_addr);
    bf16_t* smem_b = smem_a + 128 * 16;

    fill_smem_128x16_ones(smem_a, 1.0f);
    fill_smem_128x16_ones(smem_b, 2.0f);

    if (warp0) tmem_alloc(__cvta_generic_to_shared(tmem_base_slot), 256);
    __syncthreads();
    uint32_t tb = *tmem_base_slot;

    uint64_t desc_a = make_umma_desc_kmajor_none(__cvta_generic_to_shared(smem_a), 128);
    uint64_t desc_b = make_umma_desc_kmajor_none(__cvta_generic_to_shared(smem_b), 128);
    uint32_t idesc = make_idesc(128, 128);

    // SS #1 -> columns 0-127
    if (tid == 0) umma_ss_f16(tb, desc_a, desc_b, idesc, false);
    asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    __syncthreads();
    if (warp0) tmem_read_print_16(tb, 0, "Phase5 SS#1 S[0,0..15]");

    // SS #2 -> columns 128-255 (separate output region, same inputs)
    if (tid == 0) umma_ss_f16(tb + 128, desc_a, desc_b, idesc, false);
    asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    __syncthreads();
    if (warp0) tmem_read_print_16(tb, 128, "Phase5 SS#2 S[0,0..15]");
    __syncthreads();

    if (warp0) tmem_dealloc(tb, 256);
}
// ============================================================
// PHASE 6: Two TS calls in sequence
// Does TS→TS work?
// ============================================================
// Phase 6: two TS MMAs back to back in one 128-column TMEM allocation, to see
// whether a second TS MMA is a problem on its own.
//
// Launch: one block of 128 threads. Dynamic shared memory must hold a 4-byte
// TMEM base slot, padding to 16-byte alignment, and one 16x16 bf16 tile.
//
// TMEM column plan: A1 (1.0) at 0-15, C1 at 32, A2 (3.0) at 64-79, C2 at 96.
// Both MMAs share the same SMEM B operand (V, filled with 2.0).
// NOTE(review): only warp 0 writes the A operands; whether that populates all
// 128 rows the MMA reads depends on tcgen05.st lane-access rules -- confirm.
__global__ void __launch_bounds__(128)
test_phase6_ts_ts()
{
    const int tid = threadIdx.x;
    const bool warp0 = (tid / 32) == 0;

    extern __shared__ char sbuf[];
    uint32_t* tmem_base_slot = reinterpret_cast<uint32_t*>(sbuf);
    const uintptr_t tile_addr =
        (reinterpret_cast<uintptr_t>(sbuf + 4) + 15) & ~static_cast<uintptr_t>(15);
    bf16_t* smem_v = reinterpret_cast<bf16_t*>(tile_addr);

    fill_smem_16x16_ones(smem_v, 2.0f);

    if (warp0) tmem_alloc(__cvta_generic_to_shared(tmem_base_slot), 128);
    __syncthreads();
    uint32_t tb = *tmem_base_slot;

    // A1 at columns 0-15, A2 at columns 64-79
    if (warp0) {
        tmem_write_ones_128x16(tb, 0, 1.0f);
        tmem_write_ones_128x16(tb, 64, 3.0f);
        tmem_fence_store();
    }
    __syncthreads();

    uint64_t desc_v = make_umma_desc_kmajor_none(__cvta_generic_to_shared(smem_v), 16);
    uint32_t idesc = make_idesc(128, 16);

    // TS #1: A1 x V -> C1 at column 32
    if (tid == 0) umma_ts_f16(tb + 32, tb, desc_v, idesc, false);
    asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    __syncthreads();
    if (warp0) tmem_read_print_16(tb, 32, "Phase6 TS#1 C[0,0..15]");

    // TS #2: A2 x V -> C2 at column 96
    if (tid == 0) umma_ts_f16(tb + 96, tb + 64, desc_v, idesc, false);
    asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
    __syncthreads();
    if (warp0) tmem_read_print_16(tb, 96, "Phase6 TS#2 C[0,0..15]");
    __syncthreads();

    if (warp0) tmem_dealloc(tb, 128);
}
// Host driver: launches each phase kernel on one block of 128 threads, waits
// for it, and prints PASS or the CUDA error for that phase.
//
// Changes relative to the original:
//   * cudaGetLastError() is checked right after the launch. A launch that is
//     rejected (bad configuration, no kernel image for this device, ...) never
//     runs, and without this check it could be reported as PASS.
//   * The process exit code is 1 if any phase failed, 0 otherwise, so scripts
//     can detect failures without parsing the output.
int main() {
    printf("=== SS + TS Sequence Test ===\n\n");

    // Runs one phase; returns true when the kernel launched and completed
    // without a CUDA error.
    auto run = [](const char* name, void (*kernel)(), int smem) -> bool {
        printf("--- %s ---\n", name);
        kernel<<<1, 128, smem>>>();

        // Launch-time errors are reported here, not by the synchronize below.
        cudaError_t err = cudaGetLastError();
        if (err != cudaSuccess) {
            printf(" LAUNCH FAILED: %s\n\n", cudaGetErrorString(err));
            return false;
        }

        // Errors raised while the kernel executes surface at this sync point.
        err = cudaDeviceSynchronize();
        if (err != cudaSuccess) {
            printf(" CRASH: %s\n\n", cudaGetErrorString(err));
            // Reset GPU state for next test
            cudaDeviceReset();
            return false;
        }

        printf(" PASS\n\n");
        return true;
    };

    int failures = 0;

    // Dynamic shared memory per phase: 4-byte TMEM base slot + 16 bytes of
    // alignment padding + operand tiles (2 bytes per element) + 256 bytes of
    // slack, rounded up to a multiple of 128.

    // Phase 1: SS alone
    int smem1 = (4+16 + 2*128*16*2 + 256 + 127) & ~127;
    if (!run("Phase 1: SS alone", test_phase1_ss, smem1)) ++failures;

    // Phase 2: TS alone
    int smem2 = (4+16 + 16*16*2 + 256 + 127) & ~127;
    if (!run("Phase 2: TS alone", test_phase2_ts, smem2)) ++failures;

    // Phase 3: SS then TS (the crash case)
    int smem3 = (4+16 + 2*128*16*2 + 16*16*2 + 256 + 127) & ~127;
    if (!run("Phase 3: SS → TS (same TMEM)", test_phase3_ss_then_ts, smem3)) ++failures;

    // Phase 4: SS then TS with dealloc/realloc (same SMEM footprint as Phase 3)
    if (!run("Phase 4: SS → dealloc → TS (separate TMEM)", test_phase4_ss_ts_separate_tmem, smem3)) ++failures;

    // Phase 5: SS → SS
    int smem5 = (4+16 + 2*128*16*2 + 256 + 127) & ~127;
    if (!run("Phase 5: SS → SS", test_phase5_ss_ss, smem5)) ++failures;

    // Phase 6: TS → TS
    int smem6 = (4+16 + 16*16*2 + 256 + 127) & ~127;
    if (!run("Phase 6: TS → TS", test_phase6_ts_ts, smem6)) ++failures;

    return failures == 0 ? 0 : 1;
}