test: MMA→4-warp read using proven fmha_common+umma_desc infra
This commit is contained in:
@@ -1,138 +1,176 @@
|
||||
/**
|
||||
* Test: Can we do 16x256b.x1 LOADS multiple times without crashing?
|
||||
* Uses uint32_t registers (not float) per the working test_tmem_minimal.cu pattern.
|
||||
* MINIMAL test: After UMMA (QK GEMM), can 4 warps read different rows?
|
||||
*
|
||||
* Uses EXACTLY the same SMEM loading and MMA pattern as fmha_6warp_multihead.cuh
|
||||
* (which is proven to work for T=1 decode).
|
||||
*
|
||||
* Then has 4 warps each do 32x32b.x8 read and report what they see.
|
||||
*/
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cstdio>
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
#include "dsv4/kernels/attention/fmha_common.cuh"
|
||||
#include "dsv4/kernels/attention/fmha_umma_desc.cuh"
|
||||
|
||||
using bf16_t = unsigned short;
|
||||
using namespace dsv4::kernels::attention;
|
||||
|
||||
__device__ __forceinline__ bf16_t f32_to_bf16(float f) {
|
||||
bf16_t h; asm("cvt.rn.bf16.f32 %0, %1;" : "=h"(h) : "f"(f)); return h;
|
||||
}
|
||||
static bf16_t f32_to_bf16_host(float f) { uint32_t u; memcpy(&u,&f,4); return (uint16_t)(u>>16); }
|
||||
static float bf16_to_f32_host(bf16_t h) { uint32_t u=(uint32_t)h<<16; float f; memcpy(&f,&u,4); return f; }
|
||||
|
||||
__device__ void tmem_alloc(uint32_t smem_ptr, int num_cols) {
|
||||
asm volatile("tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [%0], %1;"
|
||||
:: "r"(smem_ptr), "r"(num_cols));
|
||||
}
|
||||
__device__ void tmem_dealloc(uint32_t tmem_ptr, int num_cols) {
|
||||
asm volatile("tcgen05.dealloc.cta_group::1.sync.aligned.b32 %0, %1;"
|
||||
:: "r"(tmem_ptr), "r"(num_cols));
|
||||
}
|
||||
constexpr int HD = 64;
|
||||
constexpr int SK = 128;
|
||||
constexpr int NKT_QK = HD / MMA_K_BF16; // 4
|
||||
constexpr int TILE_SZ = 128 * MMA_K_BF16; // 2048
|
||||
constexpr int TMEM_N = 128;
|
||||
constexpr int CORES_MN = 16;
|
||||
|
||||
__global__ void __launch_bounds__(32)
|
||||
test_16x256b_loads(float* results) {
|
||||
const int lane = threadIdx.x;
|
||||
const int TMEM_N = 128;
|
||||
__global__ void __launch_bounds__(192)
|
||||
test_mma_rows(float* results) {
|
||||
const int tid = threadIdx.x;
|
||||
const int wid = tid / 32;
|
||||
const int lane = tid % 32;
|
||||
|
||||
extern __shared__ char sbuf[];
|
||||
uint32_t* sTmemBase = (uint32_t*)sbuf;
|
||||
float* sRowMax = (float*)(sbuf + 8);
|
||||
bf16_t* sQ0 = (bf16_t*)(((uintptr_t)(sRowMax + 128) + 127) & ~(uintptr_t)127);
|
||||
bf16_t* sK0 = sQ0 + TILE_SZ;
|
||||
|
||||
uint32_t sp = __cvta_generic_to_shared(sTmemBase);
|
||||
tmem_alloc(sp, TMEM_N);
|
||||
__syncwarp();
|
||||
// TMEM alloc
|
||||
if (wid == 4) {
|
||||
uint32_t sp = __cvta_generic_to_shared(sTmemBase);
|
||||
tmem_alloc(sp, TMEM_N);
|
||||
}
|
||||
__syncthreads();
|
||||
uint32_t tb = *sTmemBase;
|
||||
|
||||
// Write data via 32x32b.x8 (known working)
|
||||
{
|
||||
float vals[8];
|
||||
for (int c = 0; c < 8; c++) vals[c] = (float)(lane * 10 + c);
|
||||
uint32_t ivals[8];
|
||||
for (int c = 0; c < 8; c++) memcpy(&ivals[c], &vals[c], 4);
|
||||
// Build GMEM data and copy to device... can't in-kernel.
|
||||
// Instead, build Q and K directly in SMEM using the load warp.
|
||||
// Q: (T, HD) where T=128 (full tile), Q[r, d] = (r+1) for all d
|
||||
// K: (128, HD), K[r, 0] = 1.0, rest zero
|
||||
// Then S[row, col] = Q[row, 0]*K[col, 0] * 4_K_tiles = (row+1)*1.0*4 = 4*(row+1)
|
||||
// (because Q[row, d]=row+1 for all d and K[col, 0]=1.0, and there are 4 K-tiles
|
||||
// but K only has data in d=0 which is in K-tile 0, so S[row, col] = (row+1)*1.0 for col<128)
|
||||
|
||||
asm volatile("tcgen05.st.sync.aligned.32x32b.x8.b32 [%0], {%1,%2,%3,%4,%5,%6,%7,%8};"
|
||||
:: "r"(tb + 0),
|
||||
"r"(ivals[0]), "r"(ivals[1]), "r"(ivals[2]), "r"(ivals[3]),
|
||||
"r"(ivals[4]), "r"(ivals[5]), "r"(ivals[6]), "r"(ivals[7]));
|
||||
asm volatile("tcgen05.wait::st.sync.aligned;" ::: "memory");
|
||||
// Actually simpler: Q[row, 0] = row+1, rest zero. K[col, 0] = 1.0, rest zero.
|
||||
// Then S[row, col] = (row+1) * 1.0 = row+1 for each (row, col).
|
||||
|
||||
for (int kt = 0; kt < NKT_QK; kt++) {
|
||||
// Load Q and K for this K-tile
|
||||
if (wid == 5) {
|
||||
for (int i = lane; i < TILE_SZ; i += 32) { sQ0[i] = 0; sK0[i] = 0; }
|
||||
|
||||
// Q: only first K-tile (kt==0) has data at d=0
|
||||
if (kt == 0) {
|
||||
for (int r = 0; r < 128; r++) {
|
||||
int d = 0;
|
||||
int ck = d / 8, lc = d % 8;
|
||||
int core_mn = r / 8, local_r = r % 8;
|
||||
sQ0[ck * CORES_MN * 64 + core_mn * 64 + local_r * 8 + lc] = f32_to_bf16((float)(r + 1));
|
||||
}
|
||||
// K: only d=0 has data
|
||||
for (int r = 0; r < SK; r++) {
|
||||
int d = 0;
|
||||
int ck = d / 8, lc = d % 8;
|
||||
int tmn = r / 8, lr = r % 8;
|
||||
sK0[ck * CORES_MN * 64 + tmn * 64 + lr * 8 + lc] = f32_to_bf16(1.0f);
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// MMA
|
||||
if (wid == 4) {
|
||||
uint32_t idesc = make_idesc(128, 128);
|
||||
uint64_t dq = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sQ0), 128);
|
||||
uint64_t dk = make_umma_desc_kmajor_none(__cvta_generic_to_shared(sK0), 128);
|
||||
if (tid == 128) umma_ss_f16(tb, dq, dk, idesc, kt > 0);
|
||||
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// 16x256b.x1 loads using uint32_t registers (matching test_tmem_minimal.cu)
|
||||
int load_count = 0;
|
||||
// Extra sync for TMEM visibility
|
||||
asm volatile("fence.sc.gpu;" ::: "memory");
|
||||
__syncthreads();
|
||||
|
||||
// Load column 0
|
||||
{
|
||||
uint32_t u0, u1, u2, u3;
|
||||
asm volatile("tcgen05.ld.sync.aligned.16x256b.x1.b32 {%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(u0), "=r"(u1), "=r"(u2), "=r"(u3) : "r"(tb + 0));
|
||||
// Now read TMEM with 32x32b.x8 from each of the 4 softmax warps
|
||||
// Read column group 0 (columns 0-7)
|
||||
if (wid < 4) {
|
||||
float tmp[8];
|
||||
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
|
||||
: "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
|
||||
"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
|
||||
: "r"(tb + 0));
|
||||
asm volatile("tcgen05.wait::ld.sync.aligned;");
|
||||
load_count++;
|
||||
|
||||
float v0, v1, v2, v3;
|
||||
memcpy(&v0, &u0, 4); memcpy(&v1, &u1, 4);
|
||||
memcpy(&v2, &u2, 4); memcpy(&v3, &u3, 4);
|
||||
|
||||
if (lane == 0) { results[0]=v0; results[1]=v1; results[2]=v2; results[3]=v3; }
|
||||
if (lane == 1) { results[4]=v0; results[5]=v1; results[6]=v2; results[7]=v3; }
|
||||
// Lane 0 from each warp: should see row 0's col 0 value
|
||||
// Expected: S[0, 0] = 1.0 (row 0+1)
|
||||
// But which row does each warp see?
|
||||
if (lane == 0) results[wid] = tmp[0]; // col 0
|
||||
if (lane == 0) results[4 + wid] = tmp[1]; // col 1
|
||||
// Lane 1: row 1
|
||||
if (lane == 1) results[8 + wid] = tmp[0];
|
||||
// Lane 15: row 15
|
||||
if (lane == 15) results[12 + wid] = tmp[0];
|
||||
// Lane 31: row 31
|
||||
if (lane == 31) results[16 + wid] = tmp[0];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Load column 1 (2nd load — does it crash?)
|
||||
{
|
||||
uint32_t u0, u1, u2, u3;
|
||||
asm volatile("tcgen05.ld.sync.aligned.16x256b.x1.b32 {%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(u0), "=r"(u1), "=r"(u2), "=r"(u3) : "r"(tb + 4)); // next 4 columns
|
||||
asm volatile("tcgen05.wait::ld.sync.aligned;");
|
||||
load_count++;
|
||||
|
||||
float v0, v1, v2, v3;
|
||||
memcpy(&v0, &u0, 4); memcpy(&v1, &u1, 4);
|
||||
memcpy(&v2, &u2, 4); memcpy(&v3, &u3, 4);
|
||||
|
||||
if (lane == 0) { results[8]=v0; results[9]=v1; results[10]=v2; results[11]=v3; }
|
||||
}
|
||||
|
||||
// Load column 8 (3rd load)
|
||||
{
|
||||
uint32_t u0, u1, u2, u3;
|
||||
asm volatile("tcgen05.ld.sync.aligned.16x256b.x1.b32 {%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(u0), "=r"(u1), "=r"(u2), "=r"(u3) : "r"(tb + 32)); // columns 8-11
|
||||
asm volatile("tcgen05.wait::ld.sync.aligned;");
|
||||
load_count++;
|
||||
|
||||
float v0, v1, v2, v3;
|
||||
memcpy(&v0, &u0, 4); memcpy(&v1, &u1, 4);
|
||||
memcpy(&v2, &u2, 4); memcpy(&v3, &u3, 4);
|
||||
|
||||
if (lane == 0) { results[12]=v0; results[13]=v1; results[14]=v2; results[15]=v3; }
|
||||
}
|
||||
|
||||
if (lane == 0) results[16] = (float)load_count;
|
||||
tmem_dealloc(tb, TMEM_N);
|
||||
if (wid == 4) tmem_dealloc(tb, TMEM_N);
|
||||
}
|
||||
|
||||
int main() {
|
||||
printf("16x256b.x1 multiple LOAD test\n");
|
||||
printf("==============================\n\n");
|
||||
printf("MMA → 4-warp 32x32b.x8 read\n");
|
||||
printf("=============================\n");
|
||||
printf("Q[r,0] = r+1, K[c,0] = 1.0 → S[r,c] = r+1\n");
|
||||
printf("Lane 0 in each warp should see row 0's value (1.0) if all warps see same rows,\n");
|
||||
printf("or different values if warps see different rows.\n\n");
|
||||
|
||||
float* d_r;
|
||||
cudaMalloc(&d_r, 32 * sizeof(float));
|
||||
cudaMemset(d_r, 0, 32 * sizeof(float));
|
||||
test_16x256b_loads<<<1, 32, 256>>>(d_r);
|
||||
|
||||
// SMEM: sbuf(8) + sRowMax(512) + align(128) + sQ0(4096) + sK0(4096) + slack(256) = 9000
|
||||
size_t smem_off = 8 + 128*4;
|
||||
smem_off = ((smem_off + 127) & ~(size_t)127);
|
||||
smem_off += TILE_SZ * 2 + 256;
|
||||
int smem = (int)smem_off;
|
||||
|
||||
test_mma_rows<<<1, 192, smem>>>(d_r);
|
||||
|
||||
cudaError_t err = cudaDeviceSynchronize();
|
||||
if (err != cudaSuccess) {
|
||||
printf("CUDA ERROR: %s\n", cudaGetErrorString(err));
|
||||
printf("16x256b.x1 loads CRASHED\n");
|
||||
cudaFree(d_r);
|
||||
return 1;
|
||||
printf("Kernel crashed!\n");
|
||||
} else {
|
||||
float h[32];
|
||||
cudaMemcpy(h, d_r, 32 * sizeof(float), cudaMemcpyDeviceToHost);
|
||||
|
||||
printf("Lane 0, col 0 (S[0,0]=1.0 if row 0): w0=%.1f w1=%.1f w2=%.1f w3=%.1f\n",
|
||||
h[0], h[1], h[2], h[3]);
|
||||
printf("Lane 0, col 1 (S[0,1]=1.0 if row 0): w0=%.1f w1=%.1f w2=%.1f w3=%.1f\n",
|
||||
h[4], h[5], h[6], h[7]);
|
||||
printf("Lane 1, col 0 (S[1,0]=2.0 if row 1): w0=%.1f w1=%.1f w2=%.1f w3=%.1f\n",
|
||||
h[8], h[9], h[10], h[11]);
|
||||
printf("Lane 15, col 0 (S[15,0]=16.0): w0=%.1f w1=%.1f w2=%.1f w3=%.1f\n",
|
||||
h[12], h[13], h[14], h[15]);
|
||||
printf("Lane 31, col 0 (S[31,0]=32.0): w0=%.1f w1=%.1f w2=%.1f w3=%.1f\n",
|
||||
h[16], h[17], h[18], h[19]);
|
||||
|
||||
// Check if warps see different data
|
||||
bool all_same = (h[0] == h[1]) && (h[1] == h[2]) && (h[2] == h[3]);
|
||||
bool w0_correct = (fabsf(h[0] - 1.0f) < 0.01f);
|
||||
|
||||
if (w0_correct && all_same) {
|
||||
printf("\nAll warps see the SAME 32 rows (rows 0-31). Warp 0 data is correct.\n");
|
||||
printf("For T>32, we need an alternative to 32x32b.x8 for rows 32-127.\n");
|
||||
} else if (!all_same) {
|
||||
printf("\nWarps see DIFFERENT rows! Multi-warp softmax is possible!\n");
|
||||
} else {
|
||||
printf("\nUnexpected: warp 0 data incorrect (%.1f, expected 1.0)\n", h[0]);
|
||||
}
|
||||
}
|
||||
|
||||
float h[32];
|
||||
cudaMemcpy(h, d_r, 32 * sizeof(float), cudaMemcpyDeviceToHost);
|
||||
|
||||
printf("Load count: %d (3 loads completed = no crash)\n\n", (int)h[16]);
|
||||
|
||||
printf("Col 0, lane 0 (expect rows 0-3: 0,10,20,30): %.1f %.1f %.1f %.1f\n", h[0], h[1], h[2], h[3]);
|
||||
printf("Col 0, lane 1 (expect rows 4-7: 40,50,60,70): %.1f %.1f %.1f %.1f\n", h[4], h[5], h[6], h[7]);
|
||||
printf("Col 1, lane 0 (expect rows 0-3 at col 1: 1,11,21,31): %.1f %.1f %.1f %.1f\n", h[8], h[9], h[10], h[11]);
|
||||
printf("Col 8, lane 0 (expect row 0 col 8 = 8, row 1 col 8 = 18, etc): %.1f %.1f %.1f %.1f\n", h[12], h[13], h[14], h[15]);
|
||||
|
||||
int pass = (fabsf(h[0]-0.0f)<0.01f) && (fabsf(h[1]-10.0f)<0.01f) &&
|
||||
(fabsf(h[2]-20.0f)<0.01f) && (fabsf(h[3]-30.0f)<0.01f) && (int)h[16]==3;
|
||||
printf("\nResult: %s\n", pass ? "16x256b.x1 LOADS work multiple times!" : "Data mismatch or crash");
|
||||
cudaFree(d_r);
|
||||
return pass ? 0 : 1;
|
||||
return 0;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user