diff --git a/tests/unit/test_qk_pv_layout.cu b/tests/unit/test_qk_pv_layout.cu new file mode 100644 index 00000000..131c0282 --- /dev/null +++ b/tests/unit/test_qk_pv_layout.cu @@ -0,0 +1,175 @@ +/** + * Debug test: QK (SS) → PV (TS) without softmax. + * This tests if the TS MMA can correctly read the SS MMA's output from TMEM. + * The TMEM layout of S (written by SS MMA) should be compatible with TS MMA's A format. + * If this produces reasonable results (not garbage), the layout is compatible. + * If garbage, the TMEM layouts don't match. + */ + +#include +#include +#include +#include +#include + +#include "dsv4/kernels/attention/fmha_common.cuh" +#include "dsv4/kernels/attention/fmha_umma_desc.cuh" + +using namespace dsv4::kernels::attention; + +static bf16_t f32_to_bf16_host(float f) { uint32_t u; memcpy(&u,&f,4); return (uint16_t)(u>>16); } +static float bf16_to_f32_host(bf16_t h) { uint32_t u=(uint32_t)h<<16; float f; memcpy(&f,&u,4); return f; } + +constexpr int HD = 16, SK = 128, BLOCK_MN = 128; +constexpr int NKT_QK = HD / MMA_K_BF16; // 1 +constexpr int NKT_PV = SK / MMA_K_BF16; // 8 +constexpr int TMEM_P = 128; +constexpr int TMEM_N = 256; +constexpr int TILE_SZ = BLOCK_MN * MMA_K_BF16; + +__global__ void __launch_bounds__(128) +test_qk_pv_no_softmax(const bf16_t* __restrict__ q, const bf16_t* __restrict__ k, + const bf16_t* __restrict__ v, float* __restrict__ o_mma, + float* __restrict__ o_scalar, float scale) +{ + const int tid = threadIdx.x, wid = tid / 32, lane = tid % 32; + + extern __shared__ char sbuf[]; + uint32_t* sTmemBase = (uint32_t*)sbuf; + bf16_t* sQ0 = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15); + bf16_t* sK0 = sQ0 + TILE_SZ; + bf16_t* sV = (bf16_t*)(((uintptr_t)(sK0 + TILE_SZ) + 127) & ~(uintptr_t)127); + + write_q_to_smem(sQ0, q); + write_k_to_smem(sK0, k); + + // Load V K-tiles + for (int kt = 0; kt < NKT_PV; kt++) { + bf16_t* sv = sV + kt * 256; + for (int i = tid; i < 256; i += 128) sv[i] = 0; + for (int d = tid; d < HD; d += 128) { + for (int lr = 0; lr < MMA_K_BF16; lr++) { + int r = kt * MMA_K_BF16 + lr; + int ck = d / 8, lc = d % 8; + int tmn = lr / 8, llr = lr % 8; + int dst_idx = ck * 2 * 64 + tmn * 64 + llr * 8 + lc; + sv[dst_idx] = v[d * SK + r]; + } + } + } + __syncthreads(); + + // TMEM alloc: 256 cols. 0-127 = S. 128-143 = O. + if (wid == 1) tmem_alloc(__cvta_generic_to_shared(sTmemBase), TMEM_N); + __syncthreads(); + uint32_t tb = *sTmemBase; + uint32_t tb_o = tb + TMEM_P; + + // QK GEMM + { + uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sQ0), BLOCK_MN); + uint64_t dk = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sK0), BLOCK_MN); + uint32_t idesc_qk = make_idesc(BLOCK_MN, BLOCK_MN); + for (int kt = 0; kt < NKT_QK; kt++) { + if (tid == 0) umma_ss_f16(tb, dq, dk, idesc_qk, kt > 0); + asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); + __syncthreads(); + } + } + + // NO SOFTMAX — use S directly as P (wrong mathematically, but tests TMEM layout) + + // PV GEMM: S × V → O + { + uint32_t idesc_pv = make_idesc(BLOCK_MN, HD); + for (int kt = 0; kt < NKT_PV; kt++) { + uint32_t tmem_a = tb + kt * MMA_K_BF16; + bf16_t* sv = sV + kt * 256; + uint64_t dv = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sv), 16); + if (tid == 0) umma_ts_f16(tb_o, tmem_a, dv, idesc_pv, kt > 0); + asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); + __syncthreads(); + } + } + + // Read O from TMEM (row 0 only for T=1) + if (wid == 0) { + float o_vals[HD]; + for (int n = 0; n < HD / 8; n++) { + float tmp[8]; + asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];" + : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]), + "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7]) + : "r"(tb_o + n*8)); + asm volatile("tcgen05.wait::ld.sync.aligned;"); + if (lane == 0) for (int c=0;c<8;c++) o_vals[n*8+c] = tmp[c] * 2.0f; + } + if (lane == 0) for (int d=0;d>>(d_q, d_k, d_v, d_o_mma, d_o_scalar, SCALE); + + cudaError_t err = cudaDeviceSynchronize(); + if (err != cudaSuccess) { printf("CUDA ERROR: %s\n", cudaGetErrorString(err)); return 1; } + + cudaMemcpy(h_o_mma, d_o_mma, HD*sizeof(float), cudaMemcpyDeviceToHost); + cudaMemcpy(h_o_scalar, d_o_scalar, HD*sizeof(float), cudaMemcpyDeviceToHost); + + printf("O[0..15] MMA: "); for(int d=0;d 0.999f ? "PASSED" : "FAILED"); + + cudaFree(d_q); cudaFree(d_k); cudaFree(d_v); cudaFree(d_o_mma); cudaFree(d_o_scalar); + free(h_q); free(h_k); free(h_v); free(h_o_mma); free(h_o_scalar); + return cos_sim > 0.999f ? 0 : 1; +}