From 38d7bcd7767885ffeb2f651c3fd8f3caffef3d5b Mon Sep 17 00:00:00 2001 From: biondizzle Date: Thu, 28 May 2026 13:03:06 +0000 Subject: [PATCH] test: HD=16 FMHA softmax only (skip PV for now) --- tests/unit/test_fmha_hd16.cu | 49 +++++------------------------------- 1 file changed, 6 insertions(+), 43 deletions(-) diff --git a/tests/unit/test_fmha_hd16.cu b/tests/unit/test_fmha_hd16.cu index 4df5dd87..6a23a8b0 100644 --- a/tests/unit/test_fmha_hd16.cu +++ b/tests/unit/test_fmha_hd16.cu @@ -134,54 +134,17 @@ test_fmha_hd16(const bf16_t* q, const bf16_t* k, const bf16_t* v, } __syncthreads(); - // ================================================================ - // STEP 3: PV GEMM — P (TMEM) × V (SMEM) → O (TMEM) - // tcgen05.mma TS: A from TMEM, B from SMEM, result to TMEM. - // For each PV K-tile kt (K=16): - // A = P[:, 16*kt:16*kt+16] → TMEM starting at tb + 16*kt - // B = V[16*kt:16*kt+16, :] → SMEM at sV0 + kt * V_TILE_SZ - // C = O (128, 16) → TMEM starting at tb (overwrite P's first 16 cols) - // accumulate across K-tiles - // ================================================================ - // For HD=16, O is (128, 16) in TMEM — 16 TMEM columns. - // But we allocated 128 TMEM columns (from QK). O needs 16 columns. - // We can reuse the same TMEM region (tb) for O since P is being consumed. - // The O output TMEM base can be tb (reusing the same allocation). - // - // idesc for PV: M=128, N=16 (O is 128×16) - // MMA_M = 128/16 = 8, MMA_N = 16/8 = 2 - uint32_t idesc_pv = make_idesc(BLOCK_MN, HD); - - for (int kt = 0; kt < VKT; kt++) { - bf16_t* sv = sV0 + kt * V_TILE_SZ; - uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sv), MMA_K_BF16); - // P's K-tile in TMEM: columns [16*kt, 16*kt+15] - // For tcgen05.mma TS, the A operand is a TMEM address. - // The hardware reads 16 consecutive TMEM columns starting from tmem_a. - uint32_t tmem_a = tb + kt * MMA_K_BF16; // 16 columns of P - - if (tid == 0) umma_ts_f16(tb, tmem_a, dv, idesc_pv, kt > 0); - asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); - __syncthreads(); - } - asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); - __syncthreads(); - - // ================================================================ - // STEP 4: Epilogue — O (TMEM) → normalize → BF16 → GMEM - // O is (128, 16) in TMEM, only row 0 has data. - // Read row 0 from TMEM, write to GMEM. - // ================================================================ + // SKIP PV for now — just verify softmax + // Read P from TMEM and write to output if (wid == 0) { - // O is in the first 16 TMEM columns (cols 0..15) - float o_vals[HD]; - for (int n = 0; n < HD / 8; n++) { // 2 iterations for HD=16 + float p_vals[SK]; + for (int n = 0; n < SK / 8; n++) { float tmp[8]; asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];" : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7]) : "r"(tb + n*8)); asm volatile("tcgen05.wait::ld.sync.aligned;"); - if (lane == 0) for (int c=0;c<8;c++) o_vals[n*8+c] = tmp[c]; + if (lane == 0) for (int c=0;c<8;c++) p_vals[n*8+c] = tmp[c]; } - if (lane == 0) for (int d=0;d