test: verify 4-warp TMEM read with 32x32b.x8 after MMA

This commit is contained in:
2026-05-28 22:57:59 +00:00
parent 28e04a5ea8
commit 91b03bd6bd

View File

@@ -0,0 +1,215 @@
/**
* Test: verify that 4 warps reading TMEM with 32x32b.x8 each see different rows.
*
* Write a known pattern to TMEM via UMMA (like QK GEMM),
* then have 4 warps read with 32x32b.x8 and check which rows they see.
*
* If warp w naturally sees rows [w*32, (w+1)*32), then multi-row softmax
* works with 4 warps using 32x32b.x8 and NO row offset.
*/
#include <cuda_runtime.h>
#include <cstdio>
#include <cstdint>
#include <cstring>
// Raw 16-bit storage type used to carry bfloat16 bit patterns; values are
// produced and consumed through f32_to_bf16 / bf16_to_f32 below.
using bf16_t = unsigned short;
/**
 * Convert a float to its bfloat16 bit pattern.
 *
 * Uses the PTX `cvt.rn.bf16.f32` instruction (the `.rn` modifier selects
 * round-to-nearest-even).
 *
 * @param f  Value to convert.
 * @return   16-bit bfloat16 encoding of `f`.
 */
__device__ __forceinline__ bf16_t f32_to_bf16(float f) {
    bf16_t bits;
    asm("cvt.rn.bf16.f32 %0, %1;"
        : "=h"(bits)
        : "f"(f));
    return bits;
}
/**
 * Expand a bfloat16 bit pattern to a float.
 *
 * Uses the PTX `cvt.f32.bf16` instruction.
 *
 * @param h  16-bit bfloat16 encoding.
 * @return   The value of `h` as a float.
 */
__device__ __forceinline__ float bf16_to_f32(bf16_t h) {
    float widened;
    asm("cvt.f32.bf16 %0, %1;"
        : "=f"(widened)
        : "h"(h));
    return widened;
}
/**
 * Issue `tcgen05.alloc` (cta_group::1) for `num_cols` columns.
 *
 * The instruction carries `.sync.aligned`, so it must be executed by all
 * threads of the calling warp together. The result is written by the hardware
 * to the shared-memory location `smem_ptr`.
 *
 * @param smem_ptr  Shared-state-space address (32-bit) that receives the
 *                  allocated base address.
 * @param num_cols  Number of columns to allocate.
 */
__device__ void tmem_alloc(uint32_t smem_ptr, int num_cols) {
    asm volatile(
        "tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [%0], %1;"
        :
        : "r"(smem_ptr), "r"(num_cols));
}
/**
 * Issue `tcgen05.dealloc` (cta_group::1) to release a previous allocation.
 *
 * The instruction carries `.sync.aligned`, so it must be executed by all
 * threads of the calling warp together.
 *
 * @param tmem_ptr  Base address previously produced by tmem_alloc.
 * @param num_cols  Number of columns being released; the caller in this file
 *                  passes the same count it allocated.
 */
__device__ void tmem_dealloc(uint32_t tmem_ptr, int num_cols) {
    asm volatile(
        "tcgen05.dealloc.cta_group::1.sync.aligned.b32 %0, %1;"
        :
        : "r"(tmem_ptr), "r"(num_cols));
}
/**
 * Copy a row-major (rows x cols) BF16 matrix into a zero-padded 128x16 tile
 * laid out as 8x8-element core matrices.
 *
 * Destination index for tile element (r, c):
 *     (c / 8) * 16 * 64  +  (r / 8) * 64  +  (r % 8) * 8  +  (c % 8)
 * i.e. each 8x8 core matrix is 64 contiguous elements, the 16 core matrices
 * along the row dimension follow each other, and the two column halves are
 * 1024 elements apart.
 *
 * This is a block-cooperative routine: work is strided by blockDim.x starting
 * at threadIdx.x, so EVERY thread of a 1-D block must call it for the tile to
 * be completely written. (The stride used to be hard-coded to 192, which
 * silently skipped elements for any other block size.) The caller is
 * responsible for synchronising before the tile is consumed.
 *
 * @param dst   Destination tile, 128 * 16 elements.
 * @param src   Row-major source with `cols` elements per row.
 * @param rows  Valid source rows; tile rows >= rows are zero-filled.
 * @param cols  Valid source columns; tile columns >= cols are zero-filled.
 */
__device__ void write_smem_canonical_128x16(bf16_t* dst, const bf16_t* src, int rows, int cols) {
    constexpr int TILE_ROWS  = 128;
    constexpr int TILE_COLS  = 16;
    constexpr int CORE_DIM   = 8;                      // core matrix is 8 x 8 elements
    constexpr int CORE_ELEMS = CORE_DIM * CORE_DIM;    // 64 elements per core matrix
    constexpr int CORES_MN   = TILE_ROWS / CORE_DIM;   // 16 core matrices along rows

    const int stride = blockDim.x;                     // actual block size, not a fixed 192
    const bf16_t zero = f32_to_bf16(0.0f);             // loop-invariant padding value

    for (int i = threadIdx.x; i < TILE_ROWS * TILE_COLS; i += stride) {
        const int r = i / TILE_COLS;
        const int c = i % TILE_COLS;
        const int core_mn = r / CORE_DIM, local_r = r % CORE_DIM;
        const int core_k  = c / CORE_DIM, local_c = c % CORE_DIM;
        const int dst_idx = core_k * CORES_MN * CORE_ELEMS
                          + core_mn * CORE_ELEMS
                          + local_r * CORE_DIM
                          + local_c;
        dst[dst_idx] = (r < rows && c < cols) ? src[r * cols + c] : zero;
    }
}
// UMMA descriptor + MMA
/**
 * Encode a byte quantity for a descriptor field by dropping its low 4 bits,
 * i.e. expressing it in units of 16 bytes (truncating).
 *
 * @param byte_val  Address or offset in bytes.
 * @return          byte_val / 16.
 */
__device__ __forceinline__ uint64_t desc_encode(uint64_t byte_val) {
    constexpr unsigned kUnitShift = 4;  // 2^4 = 16-byte units
    return byte_val >> kUnitShift;
}
/**
 * Pack the 64-bit shared-memory operand descriptor passed to umma_ss.
 *
 * Fields written (each is a 14-bit value in 16-byte units, see desc_encode):
 *   bits  0..13 : smem_addr
 *   bits 16..29 : block_mn * 16 bytes
 *   bits 32..45 : 128 bytes
 *   bit  46     : set to 1
 * All other bits are zero.
 *
 * NOTE(review): with the tile layout used in this file, block_mn * 16 bytes is
 * the distance between the two 8-column halves and 128 bytes is the size of
 * one 8x8 BF16 core matrix; these presumably correspond to the leading/stride
 * byte offsets of the tcgen05 descriptor, and bit 46 to its version field.
 * Confirm against the PTX ISA.
 *
 * @param smem_addr  Shared-state-space address of the operand tile.
 * @param block_mn   Tile extent along M (or N); 128 for both callers here.
 * @return           The packed descriptor.
 */
__device__ __forceinline__ uint64_t make_umma_desc(uint32_t smem_addr, int block_mn) {
    constexpr uint64_t kField14 = 0x3FFF;

    const uint64_t addr_field   = desc_encode(smem_addr) & kField14;
    const uint64_t offset_lo    = desc_encode((uint64_t)block_mn * 16) & kField14;
    const uint64_t offset_hi    = desc_encode(128ULL) & kField14;
    const uint64_t bit46        = 1ULL << 46;

    return addr_field | (offset_lo << 16) | (offset_hi << 32) | bit46;
}
/**
 * Pack the 32-bit instruction descriptor passed to umma_ss.
 *
 * Bits written:
 *   bit 4, bit 7, bit 10 : each set to 1
 *   bits 17.. : block_n / 8
 *   bits 24.. : block_m / 16
 *
 * NOTE(review): bits 4/7/10 presumably select an F32 accumulator and BF16 A/B
 * operand formats for `tcgen05.mma.kind::f16` -- confirm against the PTX ISA
 * instruction-descriptor table.
 *
 * @param block_m  MMA tile M extent.
 * @param block_n  MMA tile N extent.
 * @return         The packed descriptor.
 */
__device__ __forceinline__ uint32_t make_idesc(int block_m, int block_n) {
    uint32_t idesc = 0;
    idesc |= 1U << 4;
    idesc |= 1U << 7;
    idesc |= 1U << 10;
    idesc |= (uint32_t)(block_n >> 3) << 17;
    idesc |= (uint32_t)(block_m >> 4) << 24;
    return idesc;
}
/**
 * Issue one `tcgen05.mma.cta_group::1.kind::f16` with both operands described
 * by shared-memory descriptors.
 *
 * The instruction only enqueues the MMA; this function does not wait for it
 * to finish.
 *
 * @param tmem_c  Tensor-memory address of the accumulator/destination.
 * @param da      Descriptor for operand A (see make_umma_desc).
 * @param db      Descriptor for operand B (see make_umma_desc).
 * @param idesc   Instruction descriptor (see make_idesc).
 * @param acc     Drives the instruction's trailing predicate operand:
 *                true sets it, false clears it. NOTE(review): presumably
 *                true = accumulate into the existing destination,
 *                false = overwrite -- confirm against the PTX ISA.
 */
__device__ void umma_ss(uint32_t tmem_c, uint64_t da, uint64_t db, uint32_t idesc, bool acc) {
    // Any non-zero word makes the predicate true; 0x3F800000 is the bit
    // pattern of 1.0f.
    uint32_t accumulate_word = 0u;
    if (acc) {
        accumulate_word = 0x3F800000u;
    }
    // The four zero operands fill the {%5,%6,%7,%8} vector of the instruction.
    asm volatile("{\n\t.reg .pred p;\n\tsetp.ne.b32 p, %4, 0;\n\t"
                 "tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, {%5,%6,%7,%8}, p;\n\t}"
                 :
                 : "r"(tmem_c), "l"(da), "l"(db), "r"(idesc), "r"(accumulate_word),
                   "r"(0), "r"(0), "r"(0), "r"(0));
}
/**
 * Probe which TMEM rows each of warps 0..3 observes with
 * `tcgen05.ld ... 32x32b.x8` when all of them pass the same (unmodified) base
 * address.
 *
 * Launch contract:
 *   - grid = 1 block, block = 192 threads (6 warps), 1-D.
 *   - dynamic shared memory: 4-byte TMEM base at offset 0, 8-byte mbarrier at
 *     offset 8, then the A and B tiles (2 * 128 * 16 BF16 = 8192 bytes)
 *     starting at the first 128-byte-aligned address >= offset 256.
 *     The host in this file passes 9088 bytes.
 *   - needs a device/toolchain with tcgen05 + mbarrier support.
 *
 * Data pattern:
 *   A[r, 0] = r, B[c, 0] = 1, everything else 0, so the product
 *   S[r, c] = sum_d A[r, d] * B[c, d] equals r for every column c.
 *   (0..127 are exactly representable in BF16.)
 *
 * Output layout (`results` must hold at least 96 floats):
 *   results[ 0 + w*8 + c] : register c of lane 0  in warp w
 *   results[32 + w*8 + c] : register c of lane 1  in warp w
 *   results[64 + w*8 + c] : register c of lane 16 in warp w
 *   for w in 0..3, c in 0..7. Other entries are not written.
 *
 * Fixes relative to the previous version:
 *   - The tiles are now filled by ALL threads. Previously only warp 5 ran a
 *     loop that strides by 192 from threadIdx.x, which wrote just a small
 *     fraction of A and B and left the rest of shared memory uninitialised.
 *   - The shared-memory writes are followed by `fence.proxy.async` so the MMA
 *     (which reads shared memory asynchronously) observes them.
 *   - MMA completion is tracked with `tcgen05.commit` + an mbarrier wait.
 *     Previously nothing waited for the asynchronous MMA before TMEM was read.
 *   - `tcgen05.fence::after_thread_sync` is executed by the consumers after
 *     the synchronisation point instead of by the issuing warp before it.
 */
__global__ void __launch_bounds__(192)
test_tmem_4warp_read(float* results) {
    const int tid  = threadIdx.x;
    const int wid  = tid / 32;
    const int lane = tid % 32;

    constexpr int TMEM_N   = 128;          // TMEM columns allocated / N extent of S
    constexpr int TILE_MN  = 128;          // rows of A and of B
    constexpr int TILE_K   = 16;           // shared K extent
    constexpr int CORES_MN = TILE_MN / 8;  // 8x8 core matrices along the row dimension

    extern __shared__ char sbuf[];
    uint32_t* sTmemBase = (uint32_t*)sbuf;          // written by tcgen05.alloc
    uint64_t* sMmaBar   = (uint64_t*)(sbuf + 8);    // mbarrier signalled when the MMA is done
    bf16_t* sA = (bf16_t*)(((uintptr_t)(sbuf + 256) + 127) & ~(uintptr_t)127);
    bf16_t* sB = sA + TILE_MN * TILE_K;             // (128, 16) BF16, same layout as A

    const uint32_t bar_addr = (uint32_t)__cvta_generic_to_shared(sMmaBar);

    // ---- Fill A and B cooperatively (all threads of the block) -------------
    // Same core-matrix layout as write_smem_canonical_128x16:
    //   idx = (c/8) * CORES_MN*64 + (r/8) * 64 + (r%8) * 8 + (c%8)
    for (int i = tid; i < TILE_MN * TILE_K; i += blockDim.x) {
        const int r = i / TILE_K;
        const int c = i % TILE_K;
        const int idx = (c / 8) * CORES_MN * 64 + (r / 8) * 64 + (r % 8) * 8 + (c % 8);
        sA[idx] = f32_to_bf16(c == 0 ? (float)r : 0.0f);   // A[r, 0] = r
        sB[idx] = f32_to_bf16(c == 0 ? 1.0f : 0.0f);       // B[r, 0] = 1
    }

    // One thread initialises the completion barrier (one expected arrival).
    if (tid == 0) {
        asm volatile("mbarrier.init.shared::cta.b64 [%0], %1;"
                     :: "r"(bar_addr), "r"(1) : "memory");
        asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory");
    }

    // Make this thread's shared-memory stores visible to the async proxy
    // through which the MMA reads its operands.
    asm volatile("fence.proxy.async.shared::cta;" ::: "memory");
    __syncthreads();

    // ---- TMEM allocation (whole warp 4; the instruction is .sync.aligned) ---
    if (wid == 4) {
        uint32_t sp = __cvta_generic_to_shared(sTmemBase);
        tmem_alloc(sp, TMEM_N);
    }
    __syncthreads();
    uint32_t tb = *sTmemBase;

    // ---- MMA: A(128,16) x B(128,16)^T -> S(128,128) in TMEM -----------------
    // Issued by a single thread (warp 4, lane 0). The commit makes the
    // mbarrier complete once the MMA issued by this thread has finished.
    if (tid == 128) {
        uint32_t idesc = make_idesc(128, 128);
        uint64_t da = make_umma_desc(__cvta_generic_to_shared(sA), 128);
        uint64_t db = make_umma_desc(__cvta_generic_to_shared(sB), 128);
        umma_ss(tb, da, db, idesc, false);
        asm volatile(
            "tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64 [%0];"
            :: "r"(bar_addr) : "memory");
    }

    // Every thread waits for phase 0 of the barrier, i.e. for the MMA result
    // to be present in TMEM.
    {
        uint32_t done = 0;
        while (!done) {
            asm volatile("{\n\t.reg .pred p;\n\t"
                         "mbarrier.try_wait.parity.shared::cta.b64 p, [%1], %2;\n\t"
                         "selp.u32 %0, 1, 0, p;\n\t}"
                         : "=r"(done) : "r"(bar_addr), "r"(0) : "memory");
        }
    }
    __syncthreads();  // also reconverges the warps before the .aligned load below
    asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");

    // ---- Warps 0..3 read 8 columns starting at the allocation base ----------
    // The address is deliberately passed without any per-warp row/lane offset;
    // determining what each warp then observes is the purpose of this test.
    // NOTE(review): the PTX ISA describes per-warp TMEM lane windows selected
    // by (warp id % 4); confirm that a zero lane offset is well-defined.
    if (wid < 4) {
        float tmp[8];
        asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
                     : "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
                       "=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
                     : "r"(tb + 0));
        // "memory" keeps the compiler from moving the result stores above the wait.
        asm volatile("tcgen05.wait::ld.sync.aligned;" ::: "memory");

        // Lanes 0, 1 and 16 of each warp record their 8 registers.
        int group_base = -1;
        if (lane == 0)       group_base = 0;
        else if (lane == 1)  group_base = 32;
        else if (lane == 16) group_base = 64;

        if (group_base >= 0) {
            for (int c = 0; c < 8; c++) {
                results[group_base + wid * 8 + c] = tmp[c];
            }
        }

        // Order the TMEM reads before the barrier that precedes deallocation.
        asm volatile("tcgen05.fence::before_thread_sync;" ::: "memory");
    }
    __syncthreads();

    if (wid == 4) tmem_dealloc(tb, TMEM_N);
}
/**
 * Host driver: launches test_tmem_4warp_read on one block of 192 threads,
 * prints what lanes 0, 1 and 16 of warps 0..3 read, and classifies the result.
 *
 * Exit status:
 *   0  every recorded value matches "warp w sees rows [32w, 32w+32)"
 *   1  any other outcome (all warps see rows 0..31, an unexpected pattern,
 *      or a CUDA error)
 *
 * Fixes relative to the previous version:
 *   - `pass` was initialised to 1 and never cleared, and the process always
 *     returned 0; the exit status now reflects the outcome.
 *   - The result buffer is pre-filled with 0xFF bytes (a NaN bit pattern for
 *     float) instead of zeros, so "the kernel wrote nothing" can no longer be
 *     mistaken for "every warp read row 0".
 *   - The verdict checks all recorded values (3 lanes x 4 warps x 8 columns)
 *     rather than column 0 of lane 0 only.
 *   - Launch errors are picked up with cudaGetLastError(), and the results of
 *     cudaMalloc / cudaMemset / cudaMemcpy are checked.
 */
int main() {
    printf("TMEM 4-warp read test\n");
    printf("=====================\n");
    printf("After MMA, 4 warps read TMEM with 32x32b.x8.\n");
    printf("If each warp sees different rows, multi-row softmax works.\n\n");

    constexpr int kNumResults = 128;
    constexpr size_t kResultBytes = kNumResults * sizeof(float);
    constexpr int kNumWarps = 4;
    constexpr int kNumCols = 8;
    constexpr int kNumGroups = 3;
    const int kGroupLane[kNumGroups] = {0, 1, 16};  // lane recorded at results[g * 32 ...]

    float* d_results = nullptr;

    // Report a CUDA failure, release the device buffer and yield the exit code.
    auto bail = [&](cudaError_t e, const char* note) -> int {
        printf("CUDA ERROR: %s\n", cudaGetErrorString(e));
        if (note) printf("%s\n", note);
        cudaFree(d_results);  // no-op when d_results is still null
        return 1;
    };

    cudaError_t err = cudaMalloc(&d_results, kResultBytes);
    if (err != cudaSuccess) return bail(err, nullptr);

    // 0xFFFFFFFF is a NaN, which never compares equal to an expected value.
    err = cudaMemset(d_results, 0xFF, kResultBytes);
    if (err != cudaSuccess) return bail(err, nullptr);

    int smem = 256 + 128 + 128*16*2*2 + 512; // sbuf + sA + sB + slack
    test_tmem_4warp_read<<<1, 192, smem>>>(d_results);
    err = cudaGetLastError();                 // launch-configuration errors
    if (err == cudaSuccess) err = cudaDeviceSynchronize();  // execution errors
    if (err != cudaSuccess) return bail(err, "Kernel crashed or hung!");

    float h_results[kNumResults];
    err = cudaMemcpy(h_results, d_results, kResultBytes, cudaMemcpyDeviceToHost);
    if (err != cudaSuccess) return bail(err, nullptr);
    cudaFree(d_results);
    d_results = nullptr;

    printf("Lane 0 from each warp (expecting: warp 0=0.0, warp 1=32.0, warp 2=64.0, warp 3=96.0):\n");
    for (int w = 0; w < kNumWarps; w++) {
        printf(" Warp %d: [0]=%.1f [1]=%.1f [2]=%.1f [3]=%.1f ... [7]=%.1f\n",
               w, h_results[w*8+0], h_results[w*8+1], h_results[w*8+2], h_results[w*8+3], h_results[w*8+7]);
    }
    printf("\nLane 1 from each warp (expecting: warp 0=1.0, warp 1=33.0, warp 2=65.0, warp 3=97.0):\n");
    for (int w = 0; w < kNumWarps; w++) {
        printf(" Warp %d: [0]=%.1f [1]=%.1f ... [7]=%.1f\n",
               w, h_results[32+w*8+0], h_results[32+w*8+1], h_results[32+w*8+7]);
    }
    printf("\nLane 16 from each warp (expecting: warp 0=16.0, warp 1=48.0, warp 2=80.0, warp 3=112.0):\n");
    for (int w = 0; w < kNumWarps; w++) {
        printf(" Warp %d: [0]=%.1f [1]=%.1f ... [7]=%.1f\n",
               w, h_results[64+w*8+0], h_results[64+w*8+1], h_results[64+w*8+7]);
    }

    // S[row, col] == row for every column, so each recorded value identifies
    // the row the reading lane was mapped to:
    //   distinct rows per warp -> warp w, lane l reads 32*w + l
    //   same rows for all warps -> warp w, lane l reads l
    // The values are small integers, so exact float comparison is valid.
    bool distinct_rows = true;
    bool same_rows = true;
    for (int g = 0; g < kNumGroups; g++) {
        for (int w = 0; w < kNumWarps; w++) {
            for (int c = 0; c < kNumCols; c++) {
                const float v = h_results[g * 32 + w * 8 + c];
                if (v != (float)(32 * w + kGroupLane[g])) distinct_rows = false;
                if (v != (float)kGroupLane[g])            same_rows = false;
            }
        }
    }

    int pass = 0;
    if (distinct_rows) {
        printf("\nEach warp sees DIFFERENT rows! Multi-warp softmax works with 32x32b.x8!\n");
        pass = 1;
    } else if (same_rows) {
        printf("\nAll warps see the SAME rows (rows 0-31). Multi-warp softmax needs a different approach.\n");
    } else {
        printf("\nUnexpected pattern — need further investigation.\n");
        printf(" Warp0[0]=%.1f Warp1[0]=%.1f Warp2[0]=%.1f Warp3[0]=%.1f\n",
               h_results[0], h_results[8], h_results[16], h_results[24]);
    }
    return pass ? 0 : 1;
}