diff --git a/tests/unit/test_sw128_qk.cu b/tests/unit/test_sw128_qk.cu new file mode 100644 index 00000000..0e771237 --- /dev/null +++ b/tests/unit/test_sw128_qk.cu @@ -0,0 +1,172 @@ +/** + * Test SW128 UMMA descriptor for QK GEMM. + * If this works, we can switch from NONE to SW128 and then add TMA loads. + * + * SW128 swizzle: TMA produces data in a format that MMA can read directly. + * The SMEM layout is different from the NONE (canonical) layout. + * + * Key difference: + * NONE: SMEM is written manually in canonical core-matrix order + * SW128: SMEM is written by TMA in 128-byte swizzled order + */ + +#include +#include +#include +#include + +#include "dsv4/kernels/attention/fmha_common.cuh" +#include "dsv4/kernels/attention/fmha_umma_desc.cuh" + +using namespace dsv4::kernels::attention; + +static bf16_t f32_to_bf16_host(float f) { uint32_t u; memcpy(&u,&f,4); return (uint16_t)(u>>16); } /* host helper: float -> bf16 by keeping the upper 16 bits of the 32-bit pattern (truncation, no rounding step) */ +static float bf16_to_f32_host(bf16_t h) { uint32_t u=(uint32_t)h<<16; float f; memcpy(&f,&u,4); return f; } /* host helper: bf16 -> float by placing the 16 bits in the upper half of a 32-bit pattern */ + +constexpr int HD = 16, SK = 128, BLOCK_MN = 128; /* HD: element stride between consecutive rows of k; SK: number of K rows copied into shared memory; BLOCK_MN: extent passed to the descriptor and make_idesc helpers */ +constexpr int NKT_QK = HD / MMA_K_BF16; +constexpr int TILE_SZ = BLOCK_MN * MMA_K_BF16; /* elements per shared-memory operand tile; the kernel zero-fills TILE_SZ elements for each of sQ0 and sK0 */ + +// The SW128 SMEM layout is what TMA produces when loading a (128, 16) BF16 tile. +// For a 128×16 BF16 matrix with 128-byte swizzle: +// The swizzle permutes the 128-byte sectors to avoid bank conflicts. +// We need to write data in the SW128 layout for the MMA to read correctly. + +// From CUTLASS and the PTX spec, the SW128 layout for a (128, 16) BF16 tile: +// Each row (16 BF16 elements = 32 bytes) is stored at address: row * 16 * 2 = row * 32 bytes +// With SW128, the row address is XOR-swizzled with bits from the column index. +// The exact formula depends on the swizzle pattern. + +// For now, let's test: use the NONE (canonical) layout we know works, +// but with the SW128 descriptor. If the results are wrong, we know the +// SMEM layout must change for SW128. 
+ +__global__ void __launch_bounds__(128) /* Test kernel: copies the first MMA_K_BF16 elements of q and SK rows of k into dynamic shared memory, issues one MMA using SW128 descriptors, then reads results back from TMEM. Loops stride by 128 threads, matching the launch bound. */ +test_sw128_qk(const bf16_t* q, const bf16_t* k, + float* o_scalar, float scale) +{ + const int tid = threadIdx.x, wid = tid / 32, lane = tid % 32; + + extern __shared__ char sbuf[]; + uint32_t* sTmemBase = (uint32_t*)sbuf; + bf16_t* sQ0 = (bf16_t*)(((uintptr_t)(sbuf + 4) + 15) & ~(uintptr_t)15); + bf16_t* sK0 = sQ0 + TILE_SZ; + + // Load Q and K in CANONICAL layout (same as working NONE path). NOTE(review): there is no barrier between each zero-fill loop and the scatter that follows it, and the thread that zeroes an element (i % 128) is generally not the thread that scatters into it (d) - confirm a late zero cannot overwrite a loaded value + for (int i = tid; i < TILE_SZ; i += 128) sQ0[i] = 0; + for (int d = tid; d < MMA_K_BF16; d += 128) { + int ck = d / 8, lc = d % 8; + sQ0[ck * 16 * 64 + lc] = q[d]; + } + for (int i = tid; i < TILE_SZ; i += 128) sK0[i] = 0; + for (int r = 0; r < SK; r++) { + for (int d = tid; d < MMA_K_BF16; d += 128) { + int ck = d / 8, lc = d % 8; + int tmn = r / 8, lr = r % 8; + sK0[ck * 16 * 64 + tmn * 64 + lr * 8 + lc] = k[r * HD + d]; + } + } + __syncthreads(); + + // TMEM alloc: warp 1 calls tmem_alloc with the shared-space address of sTmemBase and 128; after the barrier every thread reads *sTmemBase (presumably written by tmem_alloc - confirm) + if (wid == 1) tmem_alloc(__cvta_generic_to_shared(sTmemBase), 128); + __syncthreads(); + uint32_t tb = *sTmemBase; + + // QK GEMM with SW128 descriptor, issued by thread 0 only. NOTE(review): whether umma_ss_f16 itself signals or waits for completion is not visible here - confirm the fence plus barrier below is sufficient before TMEM is read + { + // Build SW128 descriptors for both Q and K over the canonically laid out tiles + uint64_t dq_sw128 = make_umma_desc_kmajor_sw128(__cvta_generic_to_shared(sQ0), BLOCK_MN); + uint64_t dk_sw128 = make_umma_desc_kmajor_sw128(__cvta_generic_to_shared(sK0), BLOCK_MN); + uint32_t idesc = make_idesc(BLOCK_MN, BLOCK_MN); + if (tid == 0) umma_ss_f16(tb, dq_sw128, dk_sw128, idesc, false); + asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); + __syncthreads(); + } + + // Read S from TMEM (row 0 only): warp 0 loads 8 values per iteration starting at tb + n*8; only lane 0 scales them, stores them in s_vals and updates row_max + if (wid == 0) { + float s_vals[SK], row_max = -INFINITY; + for (int n = 0; n < SK / 8; n++) { + float tmp[8]; + asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];" + : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]), + "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7]) + : "r"(tb + n*8)); + asm volatile("tcgen05.wait::ld.sync.aligned;"); + if (lane == 0) for (int c=0;c<8;c++) { + s_vals[n*8+c] = tmp[c] * scale; + row_max = 
fmaxf(row_max, tmp[c] * scale); + } + } + row_max = wmax(row_max); /* wmax is defined outside this view; presumably a warp-wide max - confirm */ + float row_sum = 0.0f; + if (lane == 0) for (int j=0;j>>(d_q, d_k, d_o_scalar, SCALE); + + cudaError_t err = cudaDeviceSynchronize(); /* waits for the launch to finish; a failure is printed below and reported through exit code 1 */ + if (err != cudaSuccess) { printf("CUDA ERROR: %s\n", cudaGetErrorString(err)); return 1; } + + cudaFree(d_q); cudaFree(d_k); cudaFree(d_o_scalar); + free(h_q); free(h_k); + return 0; +}