test: MMA→4-warp TMEM read — do warps see different rows?

This commit is contained in:
2026-05-28 23:00:27 +00:00
parent 6b0d57074a
commit be45e87891

View File

@@ -1,7 +1,14 @@
/**
* Test: TMEM cross-warp visibility after 32x32b.x8 store.
* Try different synchronization strategies to make TMEM stores
* visible to all warps.
* Test: Can 4 warps read TMEM data written by UMMA (not 32x32b.x8 store)?
*
* The MMA writes to TMEM via the tcgen05.mma instruction, which uses
* the full 128-thread hardware pipeline. This should produce TMEM data
* that's visible to all warps in the CTA.
*
* If this works, then multi-row softmax is possible:
* - QK MMA writes S to TMEM (hardware pipeline)
* - 4 warps read S with 32x32b.x8 (each sees 32 rows)
* - The question is: does each warp see DIFFERENT 32 rows, or the same 32?
*/
#include <cuda_runtime.h>
@@ -14,6 +21,9 @@ using bf16_t = unsigned short;
__device__ __forceinline__ bf16_t f32_to_bf16(float f) {
bf16_t h; asm("cvt.rn.bf16.f32 %0, %1;" : "=h"(h) : "f"(f)); return h;
}
__device__ __forceinline__ float bf16_to_f32(bf16_t h) {
float f; asm("cvt.f32.bf16 %0, %1;" : "=f"(f) : "h"(h)); return f;
}
__device__ void tmem_alloc(uint32_t smem_ptr, int num_cols) {
asm volatile("tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [%0], %1;"
@@ -24,9 +34,42 @@ __device__ void tmem_dealloc(uint32_t tmem_ptr, int num_cols) {
:: "r"(tmem_ptr), "r"(num_cols));
}
// Strategy: warp 0 stores, then ALL warps fence + sync, then read
__global__ void __launch_bounds__(128)
test_tmem_sync(float* results, int strategy) {
__device__ __forceinline__ uint64_t desc_encode(uint64_t byte_val) { return byte_val >> 4; }
__device__ __forceinline__ uint64_t make_umma_desc(uint32_t smem_addr, int block_mn) {
uint64_t desc = 0;
desc |= desc_encode(smem_addr) & 0x3FFF;
desc |= (desc_encode((uint64_t)block_mn * 16) & 0x3FFF) << 16;
desc |= (desc_encode(128ULL) & 0x3FFF) << 32;
desc |= 1ULL << 46;
return desc;
}
__device__ __forceinline__ uint32_t make_idesc(int block_m, int block_n) {
return (1U << 4) | (1U << 7) | (1U << 10) |
((uint32_t)(block_n >> 3) << 17) | ((uint32_t)(block_m >> 4) << 24);
}
__device__ void umma_ss(uint32_t tmem_c, uint64_t da, uint64_t db, uint32_t idesc, bool acc) {
uint32_t sc = acc ? 0x3F800000u : 0u;
asm volatile("{\n\t.reg .pred p;\n\tsetp.ne.b32 p, %4, 0;\n\t"
"tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, {%5,%6,%7,%8}, p;\n\t}"
:: "r"(tmem_c), "l"(da), "l"(db), "r"(idesc), "r"(sc), "r"(0), "r"(0), "r"(0), "r"(0));
}
// Write (rows, 16) BF16 matrix to SMEM canonical layout
__device__ void write_canonical(bf16_t* dst, const bf16_t* src, int rows, int cols16) {
constexpr int CORES_MN = 16;
for (int i = threadIdx.x; i < 128 * cols16; i += 192) {
int r = i / cols16, c = i % cols16;
if (r < rows && c < cols16) {
int core_mn = r / 8, local_r = r % 8;
int core_k = c / 8, local_c = c % 8;
int dst_idx = core_k * CORES_MN * 64 + core_mn * 64 + local_r * 8 + local_c;
dst[dst_idx] = src[r * cols16 + c];
}
}
}
__global__ void __launch_bounds__(192)
test_mma_4warp_read(float* results) {
const int tid = threadIdx.x;
const int wid = tid / 32;
const int lane = tid % 32;
@@ -34,44 +77,48 @@ test_tmem_sync(float* results, int strategy) {
extern __shared__ char sbuf[];
uint32_t* sTmemBase = (uint32_t*)sbuf;
bf16_t* sA = (bf16_t*)(((uintptr_t)(sbuf + 256) + 127) & ~(uintptr_t)127);
bf16_t* sB = sA + 128 * 16;
if (wid == 0) {
// TMEM alloc (warp 4)
if (wid == 4) {
uint32_t sp = __cvta_generic_to_shared(sTmemBase);
tmem_alloc(sp, TMEM_N);
}
__syncthreads();
uint32_t tb = *sTmemBase;
// Warp 0 stores data
if (wid == 0) {
float vals[8];
for (int c = 0; c < 8; c++) vals[c] = (float)(lane * 10 + c);
uint32_t ivals[8];
for (int c = 0; c < 8; c++) memcpy(&ivals[c], &vals[c], 4);
// Build A such that A[row, 0] = row+1 (row 0 = 1.0, row 1 = 2.0, etc.)
// This way we can identify which row each lane is reading
if (wid == 5) { // load warp
bf16_t tmp[128 * 16];
for (int i = 0; i < 128 * 16; i++) tmp[i] = f32_to_bf16(0.0f);
for (int r = 0; r < 128; r++) tmp[r * 16 + 0] = f32_to_bf16((float)(r + 1));
write_canonical(sA, tmp, 128, 16);
asm volatile("tcgen05.st.sync.aligned.32x32b.x8.b32 [%0], {%1,%2,%3,%4,%5,%6,%7,%8};"
:: "r"(tb + 0),
"r"(ivals[0]), "r"(ivals[1]), "r"(ivals[2]), "r"(ivals[3]),
"r"(ivals[4]), "r"(ivals[5]), "r"(ivals[6]), "r"(ivals[7]));
asm volatile("tcgen05.wait::st.sync.aligned;" ::: "memory");
// B: B[col, 0] = 1.0 for col < 128
for (int i = 0; i < 128 * 16; i++) tmp[i] = f32_to_bf16(0.0f);
for (int c = 0; c < 128; c++) tmp[c * 16 + 0] = f32_to_bf16(1.0f);
write_canonical(sB, tmp, 128, 16);
}
__syncthreads();
// Strategy 0: just __syncthreads
// Strategy 1: __syncthreads + tcgen05.wait::st on all warps
// Strategy 2: __syncthreads + fence + __syncthreads
if (strategy == 0) {
__syncthreads();
} else if (strategy == 1) {
__syncthreads();
asm volatile("tcgen05.wait::st.sync.aligned;" ::: "memory");
__syncthreads();
} else if (strategy == 2) {
__syncthreads();
asm volatile("fence.sc.gpu;" ::: "memory");
__syncthreads();
// MMA: A(128,16) × B(128,16) → S(128,128) in TMEM
// S[row, col] should be (row+1) for col < 128
if (wid == 4) {
uint32_t idesc = make_idesc(128, 128);
uint64_t da = make_umma_desc(__cvta_generic_to_shared(sA), 128);
uint64_t db = make_umma_desc(__cvta_generic_to_shared(sB), 128);
if (tid == 128) umma_ss(tb, da, db, idesc, false);
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
}
__syncthreads();
// Extra fence for TMEM visibility
asm volatile("fence.sc.gpu;" ::: "memory");
__syncthreads();
// All 4 warps read
// 4 warps read TMEM with 32x32b.x8
// Read column group 0 (columns 0-7)
if (wid < 4) {
float tmp[8];
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
@@ -80,53 +127,60 @@ test_tmem_sync(float* results, int strategy) {
: "r"(tb + 0));
asm volatile("tcgen05.wait::ld.sync.aligned;");
// Lane 0: store col 0 value
// Lane 0: store col 0 value (should be row_id+1)
if (lane == 0) results[wid] = tmp[0];
// Lane 5: store col 0 value (should be 50.0)
if (lane == 5) results[4 + wid] = tmp[0];
// Lane 1
if (lane == 1) results[4 + wid] = tmp[0];
// Lane 15
if (lane == 15) results[8 + wid] = tmp[0];
// Lane 31
if (lane == 31) results[12 + wid] = tmp[0];
// Lane 0, also store col 1 value
if (lane == 0) results[16 + wid] = tmp[1];
}
__syncthreads();
if (wid == 0) tmem_dealloc(tb, TMEM_N);
if (wid == 4) tmem_dealloc(tb, TMEM_N);
}
int main() {
printf("TMEM cross-warp visibility test\n");
printf("================================\n\n");
printf("MMA → 4-warp TMEM read test\n");
printf("============================\n");
printf("After UMMA, S[row, col] = row+1 for col < 128.\n");
printf("If 4 warps see different rows, lane l in warp w should see row w*32+l.\n\n");
for (int strat = 0; strat < 3; strat++) {
const char* names[] = {"__syncthreads only",
"__syncthreads + tcgen05.wait::st + __syncthreads",
"__syncthreads + fence.sc.gpu + __syncthreads"};
printf("Strategy %d: %s\n", strat, names[strat]);
float* d_r;
cudaMalloc(&d_r, 32 * sizeof(float));
cudaMemset(d_r, 0, 32 * sizeof(float));
test_mma_4warp_read<<<1, 192, 4096>>>(d_r);
float* d_r;
cudaMalloc(&d_r, 32 * sizeof(float));
cudaMemset(d_r, 0, 32 * sizeof(float));
test_tmem_sync<<<1, 128, 256>>>(d_r, strat);
cudaError_t err = cudaDeviceSynchronize();
if (err != cudaSuccess) {
printf(" CUDA ERROR: %s\n", cudaGetErrorString(err));
} else {
float h[32];
cudaMemcpy(h, d_r, 32 * sizeof(float), cudaMemcpyDeviceToHost);
printf(" Lane 0, col 0 (expect 0.0 for all): ");
for(int w=0;w<4;w++) printf("w%d=%.0f ", w, h[w]);
printf("\n");
printf(" Lane 5, col 0 (expect 50.0 for all): ");
for(int w=0;w<4;w++) printf("w%d=%.0f ", w, h[4+w]);
printf("\n");
int ok = 1;
for(int w=0;w<4;w++) {
if (fabsf(h[w] - 0.0f) > 0.01f) ok = 0;
if (fabsf(h[4+w] - 50.0f) > 0.01f) ok = 0;
}
printf(" Result: %s\n\n", ok ? "ALL WARPS SEE DATA" : "SOME WARPS SEE ZEROS");
}
cudaError_t err = cudaDeviceSynchronize();
if (err != cudaSuccess) {
printf("CUDA ERROR: %s\n", cudaGetErrorString(err));
cudaFree(d_r);
return 1;
}
float h[32];
cudaMemcpy(h, d_r, 32 * sizeof(float), cudaMemcpyDeviceToHost);
printf("Lane 0, col 0 (expect row_id+1): ");
printf("w0=%.1f w1=%.1f w2=%.1f w3=%.1f\n", h[0], h[1], h[2], h[3]);
printf(" If w0=1.0 w1=33.0 w2=65.0 w3=97.0 → each warp sees different 32 rows ✓\n");
printf(" If all same → all warps see same 32 rows\n\n");
printf("Lane 1, col 0: w0=%.1f w1=%.1f w2=%.1f w3=%.1f\n", h[4], h[5], h[6], h[7]);
printf("Lane 15, col 0: w0=%.1f w1=%.1f w2=%.1f w3=%.1f\n", h[8], h[9], h[10], h[11]);
printf("Lane 31, col 0: w0=%.1f w1=%.1f w2=%.1f w3=%.1f\n", h[12], h[13], h[14], h[15]);
printf("Lane 0, col 1: w0=%.1f w1=%.1f w2=%.1f w3=%.1f\n", h[16], h[17], h[18], h[19]);
// Check if warps see different rows
int diff = (h[0] != h[1]) || (h[1] != h[2]) || (h[2] != h[3]);
printf("\nConclusion: %s\n", diff ?
"Warps see DIFFERENT rows — multi-warp softmax works!" :
"Warps see SAME rows — need alternative approach");
cudaFree(d_r);
return 0;
}