test: verify 4-warp TMEM read with 32x32b.x8 after MMA
This commit is contained in:
215
tests/unit/test_tmem_4warp_read.cu
Normal file
215
tests/unit/test_tmem_4warp_read.cu
Normal file
@@ -0,0 +1,215 @@
|
||||
/**
|
||||
* Test: verify that 4 warps reading TMEM with 32x32b.x8 each see different rows.
|
||||
*
|
||||
* Write a known pattern to TMEM via UMMA (like QK GEMM),
|
||||
* then have 4 warps read with 32x32b.x8 and check which rows they see.
|
||||
*
|
||||
* If warp w naturally sees rows [w*32, (w+1)*32), then multi-row softmax
|
||||
* works with 4 warps using 32x32b.x8 and NO row offset.
|
||||
*/
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cstdio>
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
|
||||
using bf16_t = unsigned short;
|
||||
|
||||
__device__ __forceinline__ bf16_t f32_to_bf16(float f) {
|
||||
bf16_t h; asm("cvt.rn.bf16.f32 %0, %1;" : "=h"(h) : "f"(f)); return h;
|
||||
}
|
||||
__device__ __forceinline__ float bf16_to_f32(bf16_t h) {
|
||||
float f; asm("cvt.f32.bf16 %0, %1;" : "=f"(f) : "h"(h)); return f;
|
||||
}
|
||||
|
||||
__device__ void tmem_alloc(uint32_t smem_ptr, int num_cols) {
|
||||
asm volatile("tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [%0], %1;"
|
||||
:: "r"(smem_ptr), "r"(num_cols));
|
||||
}
|
||||
__device__ void tmem_dealloc(uint32_t tmem_ptr, int num_cols) {
|
||||
asm volatile("tcgen05.dealloc.cta_group::1.sync.aligned.b32 %0, %1;"
|
||||
:: "r"(tmem_ptr), "r"(num_cols));
|
||||
}
|
||||
|
||||
// Write SMEM in canonical layout for (128, 16) BF16
|
||||
__device__ void write_smem_canonical_128x16(bf16_t* dst, const bf16_t* src, int rows, int cols) {
|
||||
const int N = 192; // total threads
|
||||
const int tid = threadIdx.x;
|
||||
constexpr int CORES_MN = 16; // 128/8
|
||||
for (int i = tid; i < 128 * 16; i += N) {
|
||||
int r = i / 16, c = i % 16;
|
||||
int core_mn = r / 8, local_r = r % 8;
|
||||
int core_k = c / 8, local_c = c % 8;
|
||||
int dst_idx = core_k * CORES_MN * 64 + core_mn * 64 + local_r * 8 + local_c;
|
||||
dst[dst_idx] = (r < rows && c < cols) ? src[r * cols + c] : f32_to_bf16(0.0f);
|
||||
}
|
||||
}
|
||||
|
||||
// UMMA descriptor + MMA
|
||||
__device__ __forceinline__ uint64_t desc_encode(uint64_t byte_val) { return byte_val >> 4; }
|
||||
__device__ __forceinline__ uint64_t make_umma_desc(uint32_t smem_addr, int block_mn) {
|
||||
uint64_t desc = 0;
|
||||
desc |= desc_encode(smem_addr) & 0x3FFF;
|
||||
desc |= (desc_encode((uint64_t)block_mn * 16) & 0x3FFF) << 16;
|
||||
desc |= (desc_encode(128ULL) & 0x3FFF) << 32;
|
||||
desc |= 1ULL << 46;
|
||||
return desc;
|
||||
}
|
||||
__device__ __forceinline__ uint32_t make_idesc(int block_m, int block_n) {
|
||||
return (1U << 4) | (1U << 7) | (1U << 10) |
|
||||
((uint32_t)(block_n >> 3) << 17) | ((uint32_t)(block_m >> 4) << 24);
|
||||
}
|
||||
__device__ void umma_ss(uint32_t tmem_c, uint64_t da, uint64_t db, uint32_t idesc, bool acc) {
|
||||
uint32_t sc = acc ? 0x3F800000u : 0u;
|
||||
asm volatile("{\n\t.reg .pred p;\n\tsetp.ne.b32 p, %4, 0;\n\t"
|
||||
"tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, {%5,%6,%7,%8}, p;\n\t}"
|
||||
:: "r"(tmem_c), "l"(da), "l"(db), "r"(idesc), "r"(sc), "r"(0), "r"(0), "r"(0), "r"(0));
|
||||
}
|
||||
|
||||
__global__ void __launch_bounds__(192)
|
||||
test_tmem_4warp_read(float* results) {
|
||||
const int tid = threadIdx.x;
|
||||
const int wid = tid / 32;
|
||||
const int lane = tid % 32;
|
||||
const int TMEM_N = 128;
|
||||
|
||||
extern __shared__ char sbuf[];
|
||||
uint32_t* sTmemBase = (uint32_t*)sbuf;
|
||||
bf16_t* sA = (bf16_t*)(((uintptr_t)(sbuf + 256) + 127) & ~(uintptr_t)127);
|
||||
bf16_t* sB = sA + 128 * 16; // (128, 16) BF16 canonical
|
||||
|
||||
// Create a known pattern: A[row, 0] = row (each row has a unique value in column 0)
|
||||
// A is (128, 16), B is (128, 16) identity-like
|
||||
// After MMA: S[row, col] = sum_d A[row, d] * B[col, d]
|
||||
// If A[row, 0] = row and B has identity structure, S[row, col] should have row info
|
||||
|
||||
// Fill A: row r has A[r, 0] = r (just column 0)
|
||||
if (wid == 5) {
|
||||
bf16_t tmp[128 * 16];
|
||||
for (int i = 0; i < 128 * 16; i++) tmp[i] = f32_to_bf16(0.0f);
|
||||
for (int r = 0; r < 128; r++) tmp[r * 16 + 0] = f32_to_bf16((float)r);
|
||||
write_smem_canonical_128x16(sA, tmp, 128, 16);
|
||||
|
||||
// Fill B: identity-like, B[col, 0] = 1 for col < 128
|
||||
for (int i = 0; i < 128 * 16; i++) tmp[i] = f32_to_bf16(0.0f);
|
||||
for (int c = 0; c < 128; c++) tmp[c * 16 + 0] = f32_to_bf16(1.0f);
|
||||
write_smem_canonical_128x16(sB, tmp, 128, 16);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// TMEM alloc (warp 4)
|
||||
if (wid == 4) {
|
||||
uint32_t sp = __cvta_generic_to_shared(sTmemBase);
|
||||
tmem_alloc(sp, TMEM_N);
|
||||
}
|
||||
__syncthreads();
|
||||
uint32_t tb = *sTmemBase;
|
||||
|
||||
// MMA: A(128,16) × B(128,16) → S(128,128) in TMEM
|
||||
// S[row, col] should be row if col < 128 (from the identity pattern)
|
||||
if (wid == 4) {
|
||||
uint32_t idesc = make_idesc(128, 128);
|
||||
uint64_t da = make_umma_desc(__cvta_generic_to_shared(sA), 128);
|
||||
uint64_t db = make_umma_desc(__cvta_generic_to_shared(sB), 128);
|
||||
if (tid == 128) umma_ss(tb, da, db, idesc, false);
|
||||
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Now have 4 warps read TMEM with 32x32b.x8
|
||||
// Each warp reads column group 0 (columns 0-7)
|
||||
// Lane l reports what value it sees
|
||||
if (wid < 4) {
|
||||
float tmp[8];
|
||||
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
|
||||
: "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
|
||||
"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
|
||||
: "r"(tb + 0));
|
||||
asm volatile("tcgen05.wait::ld.sync.aligned;");
|
||||
|
||||
// Lane 0 from each warp stores its 8 values
|
||||
if (lane == 0) {
|
||||
for (int c = 0; c < 8; c++) {
|
||||
results[wid * 8 + c] = tmp[c];
|
||||
}
|
||||
}
|
||||
// Lane 1 from each warp
|
||||
if (lane == 1) {
|
||||
for (int c = 0; c < 8; c++) {
|
||||
results[32 + wid * 8 + c] = tmp[c];
|
||||
}
|
||||
}
|
||||
// Lane 16 from each warp (to check if warps see different rows)
|
||||
if (lane == 16) {
|
||||
for (int c = 0; c < 8; c++) {
|
||||
results[64 + wid * 8 + c] = tmp[c];
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (wid == 4) tmem_dealloc(tb, TMEM_N);
|
||||
}
|
||||
|
||||
int main() {
|
||||
printf("TMEM 4-warp read test\n");
|
||||
printf("=====================\n");
|
||||
printf("After MMA, 4 warps read TMEM with 32x32b.x8.\n");
|
||||
printf("If each warp sees different rows, multi-row softmax works.\n\n");
|
||||
|
||||
float* d_results;
|
||||
cudaMalloc(&d_results, 128 * sizeof(float));
|
||||
cudaMemset(d_results, 0, 128 * sizeof(float));
|
||||
|
||||
int smem = 256 + 128 + 128*16*2*2 + 512; // sbuf + sA + sB + slack
|
||||
test_tmem_4warp_read<<<1, 192, smem>>>(d_results);
|
||||
|
||||
cudaError_t err = cudaDeviceSynchronize();
|
||||
if (err != cudaSuccess) {
|
||||
printf("CUDA ERROR: %s\n", cudaGetErrorString(err));
|
||||
printf("Kernel crashed or hung!\n");
|
||||
cudaFree(d_results);
|
||||
return 1;
|
||||
}
|
||||
|
||||
float h_results[128];
|
||||
cudaMemcpy(h_results, d_results, 128 * sizeof(float), cudaMemcpyDeviceToHost);
|
||||
|
||||
printf("Lane 0 from each warp (expecting: warp 0=0.0, warp 1=32.0, warp 2=64.0, warp 3=96.0):\n");
|
||||
for (int w = 0; w < 4; w++) {
|
||||
printf(" Warp %d: [0]=%.1f [1]=%.1f [2]=%.1f [3]=%.1f ... [7]=%.1f\n",
|
||||
w, h_results[w*8+0], h_results[w*8+1], h_results[w*8+2], h_results[w*8+3], h_results[w*8+7]);
|
||||
}
|
||||
|
||||
printf("\nLane 1 from each warp (expecting: warp 0=1.0, warp 1=33.0, warp 2=65.0, warp 3=97.0):\n");
|
||||
for (int w = 0; w < 4; w++) {
|
||||
printf(" Warp %d: [0]=%.1f [1]=%.1f ... [7]=%.1f\n",
|
||||
w, h_results[32+w*8+0], h_results[32+w*8+1], h_results[32+w*8+7]);
|
||||
}
|
||||
|
||||
printf("\nLane 16 from each warp (expecting: warp 0=16.0, warp 1=48.0, warp 2=80.0, warp 3=112.0):\n");
|
||||
for (int w = 0; w < 4; w++) {
|
||||
printf(" Warp %d: [0]=%.1f [1]=%.1f ... [7]=%.1f\n",
|
||||
w, h_results[64+w*8+0], h_results[64+w*8+1], h_results[64+w*8+7]);
|
||||
}
|
||||
|
||||
// Verify: lane 0 of warp 0 should see row 0's value
|
||||
// If S[row, col] = row (from our identity pattern), then:
|
||||
// Warp 0, lane 0, col 0 → 0.0
|
||||
// Warp 1, lane 0, col 0 → 32.0 (if warps see different rows)
|
||||
// Warp 0, lane 0, col 0 → 0.0 (if all warps see same rows)
|
||||
int pass = 1;
|
||||
if (h_results[0] == 0.0f && h_results[8] == 0.0f && h_results[16] == 0.0f && h_results[24] == 0.0f) {
|
||||
printf("\nAll warps see the SAME rows (rows 0-31). Multi-warp softmax needs a different approach.\n");
|
||||
} else if (h_results[0] == 0.0f && h_results[8] == 32.0f) {
|
||||
printf("\nEach warp sees DIFFERENT rows! Multi-warp softmax works with 32x32b.x8!\n");
|
||||
pass = 1;
|
||||
} else {
|
||||
printf("\nUnexpected pattern — need further investigation.\n");
|
||||
printf(" Warp0[0]=%.1f Warp1[0]=%.1f Warp2[0]=%.1f Warp3[0]=%.1f\n",
|
||||
h_results[0], h_results[8], h_results[16], h_results[24]);
|
||||
}
|
||||
|
||||
cudaFree(d_results);
|
||||
return 0;
|
||||
}
|
||||
Reference in New Issue
Block a user