test: MMA→4-warp read using proven fmha_common+umma_desc infra

This commit is contained in:
2026-05-28 23:05:29 +00:00
parent e1a708a187
commit fee022a485

View File

@@ -1,138 +1,176 @@
/**
* Test: Can we do 16x256b.x1 LOADS multiple times without crashing?
* Uses uint32_t registers (not float) per the working test_tmem_minimal.cu pattern.
* MINIMAL test: After UMMA (QK GEMM), can 4 warps read different rows?
*
* Uses EXACTLY the same SMEM loading and MMA pattern as fmha_6warp_multihead.cuh
* (which is proven to work for T=1 decode).
*
* Then has 4 warps each do 32x32b.x8 read and report what they see.
*/
#include <cuda_runtime.h>
#include <cstdio>
#include <cstdint>
#include <cstring>
#include "dsv4/kernels/attention/fmha_common.cuh"
#include "dsv4/kernels/attention/fmha_umma_desc.cuh"
using bf16_t = unsigned short;
using namespace dsv4::kernels::attention;
__device__ __forceinline__ bf16_t f32_to_bf16(float f) {
bf16_t h; asm("cvt.rn.bf16.f32 %0, %1;" : "=h"(h) : "f"(f)); return h;
}
static bf16_t f32_to_bf16_host(float f) { uint32_t u; memcpy(&u,&f,4); return (uint16_t)(u>>16); }
static float bf16_to_f32_host(bf16_t h) { uint32_t u=(uint32_t)h<<16; float f; memcpy(&f,&u,4); return f; }
__device__ void tmem_alloc(uint32_t smem_ptr, int num_cols) {
asm volatile("tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [%0], %1;"
:: "r"(smem_ptr), "r"(num_cols));
}
__device__ void tmem_dealloc(uint32_t tmem_ptr, int num_cols) {
asm volatile("tcgen05.dealloc.cta_group::1.sync.aligned.b32 %0, %1;"
:: "r"(tmem_ptr), "r"(num_cols));
}
constexpr int HD = 64;
constexpr int SK = 128;
constexpr int NKT_QK = HD / MMA_K_BF16; // 4
constexpr int TILE_SZ = 128 * MMA_K_BF16; // 2048
constexpr int TMEM_N = 128;
constexpr int CORES_MN = 16;
__global__ void __launch_bounds__(32)
test_16x256b_loads(float* results) {
const int lane = threadIdx.x;
const int TMEM_N = 128;
__global__ void __launch_bounds__(192)
test_mma_rows(float* results) {
const int tid = threadIdx.x;
const int wid = tid / 32;
const int lane = tid % 32;
extern __shared__ char sbuf[];
uint32_t* sTmemBase = (uint32_t*)sbuf;
float* sRowMax = (float*)(sbuf + 8);
bf16_t* sQ0 = (bf16_t*)(((uintptr_t)(sRowMax + 128) + 127) & ~(uintptr_t)127);
bf16_t* sK0 = sQ0 + TILE_SZ;
uint32_t sp = __cvta_generic_to_shared(sTmemBase);
tmem_alloc(sp, TMEM_N);
__syncwarp();
// TMEM alloc
if (wid == 4) {
uint32_t sp = __cvta_generic_to_shared(sTmemBase);
tmem_alloc(sp, TMEM_N);
}
__syncthreads();
uint32_t tb = *sTmemBase;
// Write data via 32x32b.x8 (known working)
{
float vals[8];
for (int c = 0; c < 8; c++) vals[c] = (float)(lane * 10 + c);
uint32_t ivals[8];
for (int c = 0; c < 8; c++) memcpy(&ivals[c], &vals[c], 4);
// Build GMEM data and copy to device... can't in-kernel.
// Instead, build Q and K directly in SMEM using the load warp.
// Q: (T, HD) where T=128 (full tile), Q[r, d] = (r+1) for all d
// K: (128, HD), K[r, 0] = 1.0, rest zero
// Then S[row, col] = Q[row, 0]*K[col, 0] * 4_K_tiles = (row+1)*1.0*4 = 4*(row+1)
// (because Q[row, d]=row+1 for all d and K[col, 0]=1.0, and there are 4 K-tiles
// but K only has data in d=0 which is in K-tile 0, so S[row, col] = (row+1)*1.0 for col<128)
asm volatile("tcgen05.st.sync.aligned.32x32b.x8.b32 [%0], {%1,%2,%3,%4,%5,%6,%7,%8};"
:: "r"(tb + 0),
"r"(ivals[0]), "r"(ivals[1]), "r"(ivals[2]), "r"(ivals[3]),
"r"(ivals[4]), "r"(ivals[5]), "r"(ivals[6]), "r"(ivals[7]));
asm volatile("tcgen05.wait::st.sync.aligned;" ::: "memory");
// Actually simpler: Q[row, 0] = row+1, rest zero. K[col, 0] = 1.0, rest zero.
// Then S[row, col] = (row+1) * 1.0 = row+1 for each (row, col).
for (int kt = 0; kt < NKT_QK; kt++) {
// Load Q and K for this K-tile
if (wid == 5) {
for (int i = lane; i < TILE_SZ; i += 32) { sQ0[i] = 0; sK0[i] = 0; }
// Q: only first K-tile (kt==0) has data at d=0
if (kt == 0) {
for (int r = 0; r < 128; r++) {
int d = 0;
int ck = d / 8, lc = d % 8;
int core_mn = r / 8, local_r = r % 8;
sQ0[ck * CORES_MN * 64 + core_mn * 64 + local_r * 8 + lc] = f32_to_bf16((float)(r + 1));
}
// K: only d=0 has data
for (int r = 0; r < SK; r++) {
int d = 0;
int ck = d / 8, lc = d % 8;
int tmn = r / 8, lr = r % 8;
sK0[ck * CORES_MN * 64 + tmn * 64 + lr * 8 + lc] = f32_to_bf16(1.0f);
}
}
}
__syncthreads();
// MMA
if (wid == 4) {
uint32_t idesc = make_idesc(128, 128);
uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sQ0), 128);
uint64_t dk = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sK0), 128);
if (tid == 128) umma_ss_f16(tb, dq, dk, idesc, kt > 0);
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
}
__syncthreads();
}
// 16x256b.x1 loads using uint32_t registers (matching test_tmem_minimal.cu)
int load_count = 0;
// Extra sync for TMEM visibility
asm volatile("fence.sc.gpu;" ::: "memory");
__syncthreads();
// Load column 0
{
uint32_t u0, u1, u2, u3;
asm volatile("tcgen05.ld.sync.aligned.16x256b.x1.b32 {%0, %1, %2, %3}, [%4];"
: "=r"(u0), "=r"(u1), "=r"(u2), "=r"(u3) : "r"(tb + 0));
// Now read TMEM with 32x32b.x8 from each of the 4 softmax warps
// Read column group 0 (columns 0-7)
if (wid < 4) {
float tmp[8];
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
: "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
: "r"(tb + 0));
asm volatile("tcgen05.wait::ld.sync.aligned;");
load_count++;
float v0, v1, v2, v3;
memcpy(&v0, &u0, 4); memcpy(&v1, &u1, 4);
memcpy(&v2, &u2, 4); memcpy(&v3, &u3, 4);
if (lane == 0) { results[0]=v0; results[1]=v1; results[2]=v2; results[3]=v3; }
if (lane == 1) { results[4]=v0; results[5]=v1; results[6]=v2; results[7]=v3; }
// Lane 0 from each warp: should see row 0's col 0 value
// Expected: S[0, 0] = 1.0 (row 0+1)
// But which row does each warp see?
if (lane == 0) results[wid] = tmp[0]; // col 0
if (lane == 0) results[4 + wid] = tmp[1]; // col 1
// Lane 1: row 1
if (lane == 1) results[8 + wid] = tmp[0];
// Lane 15: row 15
if (lane == 15) results[12 + wid] = tmp[0];
// Lane 31: row 31
if (lane == 31) results[16 + wid] = tmp[0];
}
__syncthreads();
// Load column 1 (2nd load — does it crash?)
{
uint32_t u0, u1, u2, u3;
asm volatile("tcgen05.ld.sync.aligned.16x256b.x1.b32 {%0, %1, %2, %3}, [%4];"
: "=r"(u0), "=r"(u1), "=r"(u2), "=r"(u3) : "r"(tb + 4)); // next 4 columns
asm volatile("tcgen05.wait::ld.sync.aligned;");
load_count++;
float v0, v1, v2, v3;
memcpy(&v0, &u0, 4); memcpy(&v1, &u1, 4);
memcpy(&v2, &u2, 4); memcpy(&v3, &u3, 4);
if (lane == 0) { results[8]=v0; results[9]=v1; results[10]=v2; results[11]=v3; }
}
// Load column 8 (3rd load)
{
uint32_t u0, u1, u2, u3;
asm volatile("tcgen05.ld.sync.aligned.16x256b.x1.b32 {%0, %1, %2, %3}, [%4];"
: "=r"(u0), "=r"(u1), "=r"(u2), "=r"(u3) : "r"(tb + 32)); // columns 8-11
asm volatile("tcgen05.wait::ld.sync.aligned;");
load_count++;
float v0, v1, v2, v3;
memcpy(&v0, &u0, 4); memcpy(&v1, &u1, 4);
memcpy(&v2, &u2, 4); memcpy(&v3, &u3, 4);
if (lane == 0) { results[12]=v0; results[13]=v1; results[14]=v2; results[15]=v3; }
}
if (lane == 0) results[16] = (float)load_count;
tmem_dealloc(tb, TMEM_N);
if (wid == 4) tmem_dealloc(tb, TMEM_N);
}
int main() {
printf("16x256b.x1 multiple LOAD test\n");
printf("==============================\n\n");
printf("MMA → 4-warp 32x32b.x8 read\n");
printf("=============================\n");
printf("Q[r,0] = r+1, K[c,0] = 1.0 → S[r,c] = r+1\n");
printf("Lane 0 in each warp should see row 0's value (1.0) if all warps see same rows,\n");
printf("or different values if warps see different rows.\n\n");
float* d_r;
cudaMalloc(&d_r, 32 * sizeof(float));
cudaMemset(d_r, 0, 32 * sizeof(float));
test_16x256b_loads<<<1, 32, 256>>>(d_r);
// SMEM: sbuf(8) + sRowMax(512) + align(128) + sQ0(4096) + sK0(4096) + slack(256) = 9000
size_t smem_off = 8 + 128*4;
smem_off = ((smem_off + 127) & ~(size_t)127);
smem_off += TILE_SZ * 2 + 256;
int smem = (int)smem_off;
test_mma_rows<<<1, 192, smem>>>(d_r);
cudaError_t err = cudaDeviceSynchronize();
if (err != cudaSuccess) {
printf("CUDA ERROR: %s\n", cudaGetErrorString(err));
printf("16x256b.x1 loads CRASHED\n");
cudaFree(d_r);
return 1;
printf("Kernel crashed!\n");
} else {
float h[32];
cudaMemcpy(h, d_r, 32 * sizeof(float), cudaMemcpyDeviceToHost);
printf("Lane 0, col 0 (S[0,0]=1.0 if row 0): w0=%.1f w1=%.1f w2=%.1f w3=%.1f\n",
h[0], h[1], h[2], h[3]);
printf("Lane 0, col 1 (S[0,1]=1.0 if row 0): w0=%.1f w1=%.1f w2=%.1f w3=%.1f\n",
h[4], h[5], h[6], h[7]);
printf("Lane 1, col 0 (S[1,0]=2.0 if row 1): w0=%.1f w1=%.1f w2=%.1f w3=%.1f\n",
h[8], h[9], h[10], h[11]);
printf("Lane 15, col 0 (S[15,0]=16.0): w0=%.1f w1=%.1f w2=%.1f w3=%.1f\n",
h[12], h[13], h[14], h[15]);
printf("Lane 31, col 0 (S[31,0]=32.0): w0=%.1f w1=%.1f w2=%.1f w3=%.1f\n",
h[16], h[17], h[18], h[19]);
// Check if warps see different data
bool all_same = (h[0] == h[1]) && (h[1] == h[2]) && (h[2] == h[3]);
bool w0_correct = (fabsf(h[0] - 1.0f) < 0.01f);
if (w0_correct && all_same) {
printf("\nAll warps see the SAME 32 rows (rows 0-31). Warp 0 data is correct.\n");
printf("For T>32, we need an alternative to 32x32b.x8 for rows 32-127.\n");
} else if (!all_same) {
printf("\nWarps see DIFFERENT rows! Multi-warp softmax is possible!\n");
} else {
printf("\nUnexpected: warp 0 data incorrect (%.1f, expected 1.0)\n", h[0]);
}
}
float h[32];
cudaMemcpy(h, d_r, 32 * sizeof(float), cudaMemcpyDeviceToHost);
printf("Load count: %d (3 loads completed = no crash)\n\n", (int)h[16]);
printf("Col 0, lane 0 (expect rows 0-3: 0,10,20,30): %.1f %.1f %.1f %.1f\n", h[0], h[1], h[2], h[3]);
printf("Col 0, lane 1 (expect rows 4-7: 40,50,60,70): %.1f %.1f %.1f %.1f\n", h[4], h[5], h[6], h[7]);
printf("Col 1, lane 0 (expect rows 0-3 at col 1: 1,11,21,31): %.1f %.1f %.1f %.1f\n", h[8], h[9], h[10], h[11]);
printf("Col 8, lane 0 (expect row 0 col 8 = 8, row 1 col 8 = 18, etc): %.1f %.1f %.1f %.1f\n", h[12], h[13], h[14], h[15]);
int pass = (fabsf(h[0]-0.0f)<0.01f) && (fabsf(h[1]-10.0f)<0.01f) &&
(fabsf(h[2]-20.0f)<0.01f) && (fabsf(h[3]-30.0f)<0.01f) && (int)h[16]==3;
printf("\nResult: %s\n", pass ? "16x256b.x1 LOADS work multiple times!" : "Data mismatch or crash");
cudaFree(d_r);
return pass ? 0 : 1;
return 0;
}