test: 16x256b.x1 multiple LOADS — do they crash like stores?
This commit is contained in:
@@ -1,14 +1,10 @@
|
||||
/**
|
||||
* Test: Can 4 warps read TMEM data written by UMMA (not 32x32b.x8 store)?
|
||||
* Test: Can we do 16x256b.x1 LOADS multiple times without crashing?
|
||||
* (The crash was on 16x256b.x1 STORES, not loads.)
|
||||
*
|
||||
* The MMA writes to TMEM via the tcgen05.mma instruction, which uses
|
||||
* the full 128-thread hardware pipeline. This should produce TMEM data
|
||||
* that's visible to all warps in the CTA.
|
||||
*
|
||||
* If this works, then multi-row softmax is possible:
|
||||
* - QK MMA writes S to TMEM (hardware pipeline)
|
||||
* - 4 warps read S with 32x32b.x8 (each sees 32 rows)
|
||||
* - The question is: does each warp see DIFFERENT 32 rows, or the same 32?
|
||||
* If loads work multiple times, we can:
|
||||
* - Use 16x256b.x1 for softmax reads (128 rows, 1 column per call)
|
||||
* - Use 32x32b.x8 for everything else (stores, PV, epilogue)
|
||||
*/
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
@@ -21,9 +17,6 @@ using bf16_t = unsigned short;
|
||||
__device__ __forceinline__ bf16_t f32_to_bf16(float f) {
|
||||
bf16_t h; asm("cvt.rn.bf16.f32 %0, %1;" : "=h"(h) : "f"(f)); return h;
|
||||
}
|
||||
__device__ __forceinline__ float bf16_to_f32(bf16_t h) {
|
||||
float f; asm("cvt.f32.bf16 %0, %1;" : "=f"(f) : "h"(h)); return f;
|
||||
}
|
||||
|
||||
__device__ void tmem_alloc(uint32_t smem_ptr, int num_cols) {
|
||||
asm volatile("tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [%0], %1;"
|
||||
@@ -34,131 +27,125 @@ __device__ void tmem_dealloc(uint32_t tmem_ptr, int num_cols) {
|
||||
:: "r"(tmem_ptr), "r"(num_cols));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ uint64_t desc_encode(uint64_t byte_val) { return byte_val >> 4; }
|
||||
__device__ __forceinline__ uint64_t make_umma_desc(uint32_t smem_addr, int block_mn) {
|
||||
uint64_t desc = 0;
|
||||
desc |= desc_encode(smem_addr) & 0x3FFF;
|
||||
desc |= (desc_encode((uint64_t)block_mn * 16) & 0x3FFF) << 16;
|
||||
desc |= (desc_encode(128ULL) & 0x3FFF) << 32;
|
||||
desc |= 1ULL << 46;
|
||||
return desc;
|
||||
}
|
||||
__device__ __forceinline__ uint32_t make_idesc(int block_m, int block_n) {
|
||||
return (1U << 4) | (1U << 7) | (1U << 10) |
|
||||
((uint32_t)(block_n >> 3) << 17) | ((uint32_t)(block_m >> 4) << 24);
|
||||
}
|
||||
__device__ void umma_ss(uint32_t tmem_c, uint64_t da, uint64_t db, uint32_t idesc, bool acc) {
|
||||
uint32_t sc = acc ? 0x3F800000u : 0u;
|
||||
asm volatile("{\n\t.reg .pred p;\n\tsetp.ne.b32 p, %4, 0;\n\t"
|
||||
"tcgen05.mma.cta_group::1.kind::f16 [%0], %1, %2, %3, {%5,%6,%7,%8}, p;\n\t}"
|
||||
:: "r"(tmem_c), "l"(da), "l"(db), "r"(idesc), "r"(sc), "r"(0), "r"(0), "r"(0), "r"(0));
|
||||
}
|
||||
|
||||
// Write (rows, 16) BF16 matrix to SMEM canonical layout
|
||||
__device__ void write_canonical(bf16_t* dst, const bf16_t* src, int rows, int cols16) {
|
||||
constexpr int CORES_MN = 16;
|
||||
for (int i = threadIdx.x; i < 128 * cols16; i += 192) {
|
||||
int r = i / cols16, c = i % cols16;
|
||||
if (r < rows && c < cols16) {
|
||||
int core_mn = r / 8, local_r = r % 8;
|
||||
int core_k = c / 8, local_c = c % 8;
|
||||
int dst_idx = core_k * CORES_MN * 64 + core_mn * 64 + local_r * 8 + local_c;
|
||||
dst[dst_idx] = src[r * cols16 + c];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void __launch_bounds__(192)
|
||||
test_mma_4warp_read(float* results) {
|
||||
const int tid = threadIdx.x;
|
||||
const int wid = tid / 32;
|
||||
const int lane = tid % 32;
|
||||
__global__ void __launch_bounds__(32)
|
||||
test_16x256b_loads(float* results) {
|
||||
const int lane = threadIdx.x;
|
||||
const int TMEM_N = 128;
|
||||
|
||||
extern __shared__ char sbuf[];
|
||||
uint32_t* sTmemBase = (uint32_t*)sbuf;
|
||||
bf16_t* sA = (bf16_t*)(((uintptr_t)(sbuf + 256) + 127) & ~(uintptr_t)127);
|
||||
bf16_t* sB = sA + 128 * 16;
|
||||
|
||||
// TMEM alloc (warp 4)
|
||||
if (wid == 4) {
|
||||
uint32_t sp = __cvta_generic_to_shared(sTmemBase);
|
||||
tmem_alloc(sp, TMEM_N);
|
||||
}
|
||||
__syncthreads();
|
||||
uint32_t sp = __cvta_generic_to_shared(sTmemBase);
|
||||
tmem_alloc(sp, TMEM_N);
|
||||
__syncwarp();
|
||||
uint32_t tb = *sTmemBase;
|
||||
|
||||
// Build A such that A[row, 0] = row+1 (row 0 = 1.0, row 1 = 2.0, etc.)
|
||||
// This way we can identify which row each lane is reading
|
||||
if (wid == 5) { // load warp
|
||||
bf16_t tmp[128 * 16];
|
||||
for (int i = 0; i < 128 * 16; i++) tmp[i] = f32_to_bf16(0.0f);
|
||||
for (int r = 0; r < 128; r++) tmp[r * 16 + 0] = f32_to_bf16((float)(r + 1));
|
||||
write_canonical(sA, tmp, 128, 16);
|
||||
// Write data via 32x32b.x8 (known working for stores)
|
||||
{
|
||||
float vals[8];
|
||||
for (int c = 0; c < 8; c++) vals[c] = (float)(lane * 10 + c);
|
||||
uint32_t ivals[8];
|
||||
for (int c = 0; c < 8; c++) memcpy(&ivals[c], &vals[c], 4);
|
||||
|
||||
// B: B[col, 0] = 1.0 for col < 128
|
||||
for (int i = 0; i < 128 * 16; i++) tmp[i] = f32_to_bf16(0.0f);
|
||||
for (int c = 0; c < 128; c++) tmp[c * 16 + 0] = f32_to_bf16(1.0f);
|
||||
write_canonical(sB, tmp, 128, 16);
|
||||
// Write to column groups 0-3 (32 columns)
|
||||
for (int n = 0; n < 4; n++) {
|
||||
asm volatile("tcgen05.st.sync.aligned.32x32b.x8.b32 [%0], {%1,%2,%3,%4,%5,%6,%7,%8};"
|
||||
:: "r"(tb + n * 8),
|
||||
"r"(ivals[0]), "r"(ivals[1]), "r"(ivals[2]), "r"(ivals[3]),
|
||||
"r"(ivals[4]), "r"(ivals[5]), "r"(ivals[6]), "r"(ivals[7]));
|
||||
asm volatile("tcgen05.wait::st.sync.aligned;" ::: "memory");
|
||||
// Update vals for next column group
|
||||
for (int c = 0; c < 8; c++) vals[c] = (float)(lane * 10 + n * 8 + c);
|
||||
for (int c = 0; c < 8; c++) memcpy(&ivals[c], &vals[c], 4);
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// MMA: A(128,16) × B(128,16) → S(128,128) in TMEM
|
||||
// S[row, col] should be (row+1) for col < 128
|
||||
if (wid == 4) {
|
||||
uint32_t idesc = make_idesc(128, 128);
|
||||
uint64_t da = make_umma_desc(__cvta_generic_to_shared(sA), 128);
|
||||
uint64_t db = make_umma_desc(__cvta_generic_to_shared(sB), 128);
|
||||
if (tid == 128) umma_ss(tb, da, db, idesc, false);
|
||||
asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
|
||||
}
|
||||
__syncthreads();
|
||||
// Extra fence for TMEM visibility
|
||||
asm volatile("fence.sc.gpu;" ::: "memory");
|
||||
__syncthreads();
|
||||
// Now try reading with 16x256b.x1 loads
|
||||
// 16x256b.x1: lane l reads 4 FP32 values (rows l*4+0..3) from 1 column
|
||||
int load_count = 0;
|
||||
int pass = 1;
|
||||
|
||||
// 4 warps read TMEM with 32x32b.x8
|
||||
// Read column group 0 (columns 0-7)
|
||||
if (wid < 4) {
|
||||
float tmp[8];
|
||||
asm volatile("tcgen05.ld.sync.aligned.32x32b.x8.b32 {%0,%1,%2,%3,%4,%5,%6,%7},[%8];"
|
||||
: "=f"(tmp[0]),"=f"(tmp[1]),"=f"(tmp[2]),"=f"(tmp[3]),
|
||||
"=f"(tmp[4]),"=f"(tmp[5]),"=f"(tmp[6]),"=f"(tmp[7])
|
||||
: "r"(tb + 0));
|
||||
// Read column 0 — lane 0 should get rows 0-3, lane 1 should get rows 4-7, etc.
|
||||
{
|
||||
float v0, v1, v2, v3;
|
||||
asm volatile("tcgen05.ld.sync.aligned.16x256b.x1.b32 {%0, %1, %2, %3}, [%4, %5];"
|
||||
: "=f"(v0), "=f"(v1), "=f"(v2), "=f"(v3)
|
||||
: "r"(tb), "r"(0)); // column 0
|
||||
asm volatile("tcgen05.wait::ld.sync.aligned;");
|
||||
load_count++;
|
||||
|
||||
// Lane 0: store col 0 value (should be row_id+1)
|
||||
if (lane == 0) results[wid] = tmp[0];
|
||||
// Lane 1
|
||||
if (lane == 1) results[4 + wid] = tmp[0];
|
||||
// Lane 15
|
||||
if (lane == 15) results[8 + wid] = tmp[0];
|
||||
// Lane 31
|
||||
if (lane == 31) results[12 + wid] = tmp[0];
|
||||
|
||||
// Lane 0, also store col 1 value
|
||||
if (lane == 0) results[16 + wid] = tmp[1];
|
||||
// Lane l should see values for rows l*4+0..3
|
||||
// From the store: row 0 col 0 = 0.0, row 1 col 0 = 10.0, row 2 col 0 = 20.0, ...
|
||||
// Wait — the 32x32b.x8 store wrote lane l's data to "row l" in TMEM.
|
||||
// So row l, col 0 = l*10 + 0 = l*10
|
||||
// Lane 0 reads rows 0-3: v0=row0=0, v1=row1=10, v2=row2=20, v3=row3=30
|
||||
if (lane == 0) {
|
||||
results[0] = v0;
|
||||
results[1] = v1;
|
||||
results[2] = v2;
|
||||
results[3] = v3;
|
||||
}
|
||||
if (lane == 1) {
|
||||
results[4] = v0;
|
||||
results[5] = v1;
|
||||
results[6] = v2;
|
||||
results[7] = v3;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (wid == 4) tmem_dealloc(tb, TMEM_N);
|
||||
// Read column 1 (2nd 16x256b.x1 load — does it crash?)
|
||||
{
|
||||
float v0, v1, v2, v3;
|
||||
asm volatile("tcgen05.ld.sync.aligned.16x256b.x1.b32 {%0, %1, %2, %3}, [%4, %5];"
|
||||
: "=f"(v0), "=f"(v1), "=f"(v2), "=f"(v3)
|
||||
: "r"(tb), "r"(1)); // column 1
|
||||
asm volatile("tcgen05.wait::ld.sync.aligned;");
|
||||
load_count++;
|
||||
|
||||
if (lane == 0) {
|
||||
results[8] = v0;
|
||||
results[9] = v1;
|
||||
results[10] = v2;
|
||||
results[11] = v3;
|
||||
}
|
||||
}
|
||||
|
||||
// Read column 8 (8th column — more 16x256b.x1 loads)
|
||||
{
|
||||
float v0, v1, v2, v3;
|
||||
asm volatile("tcgen05.ld.sync.aligned.16x256b.x1.b32 {%0, %1, %2, %3}, [%4, %5];"
|
||||
: "=f"(v0), "=f"(v1), "=f"(v2), "=f"(v3)
|
||||
: "r"(tb), "r"(8));
|
||||
asm volatile("tcgen05.wait::ld.sync.aligned;");
|
||||
load_count++;
|
||||
|
||||
if (lane == 0) {
|
||||
results[12] = v0; // row 0, col 8 = 0*10+8 = 8.0
|
||||
results[13] = v1; // row 1, col 8 = 10+8 = 18.0? No, 1*10+0=10 + 8 = 18
|
||||
results[14] = v2;
|
||||
results[15] = v3;
|
||||
}
|
||||
}
|
||||
|
||||
// Store load count (if we get here, loads didn't crash)
|
||||
if (lane == 0) results[16] = (float)load_count;
|
||||
|
||||
tmem_dealloc(tb, TMEM_N);
|
||||
}
|
||||
|
||||
int main() {
|
||||
printf("MMA → 4-warp TMEM read test\n");
|
||||
printf("============================\n");
|
||||
printf("After UMMA, S[row, col] = row+1 for col < 128.\n");
|
||||
printf("If 4 warps see different rows, lane l in warp w should see row w*32+l.\n\n");
|
||||
printf("16x256b.x1 multiple LOAD test\n");
|
||||
printf("==============================\n\n");
|
||||
|
||||
float* d_r;
|
||||
cudaMalloc(&d_r, 32 * sizeof(float));
|
||||
cudaMemset(d_r, 0, 32 * sizeof(float));
|
||||
int smem = 256 + 128 + 128*16*2*2 + 256; // sbuf + align + sA + sB + slack
|
||||
test_mma_4warp_read<<<1, 192, smem>>>(d_r);
|
||||
test_16x256b_loads<<<1, 32, 256>>>(d_r);
|
||||
|
||||
cudaError_t err = cudaDeviceSynchronize();
|
||||
if (err != cudaSuccess) {
|
||||
printf("CUDA ERROR: %s\n", cudaGetErrorString(err));
|
||||
printf("16x256b.x1 loads CRASHED (likely on 2nd or 3rd call)\n");
|
||||
cudaFree(d_r);
|
||||
return 1;
|
||||
}
|
||||
@@ -166,22 +153,22 @@ int main() {
|
||||
float h[32];
|
||||
cudaMemcpy(h, d_r, 32 * sizeof(float), cudaMemcpyDeviceToHost);
|
||||
|
||||
printf("Lane 0, col 0 (expect row_id+1): ");
|
||||
printf("w0=%.1f w1=%.1f w2=%.1f w3=%.1f\n", h[0], h[1], h[2], h[3]);
|
||||
printf(" If w0=1.0 w1=33.0 w2=65.0 w3=97.0 → each warp sees different 32 rows ✓\n");
|
||||
printf(" If all same → all warps see same 32 rows\n\n");
|
||||
printf("Load count: %d (3 loads completed = no crash)\n\n", (int)h[16]);
|
||||
|
||||
printf("Lane 1, col 0: w0=%.1f w1=%.1f w2=%.1f w3=%.1f\n", h[4], h[5], h[6], h[7]);
|
||||
printf("Lane 15, col 0: w0=%.1f w1=%.1f w2=%.1f w3=%.1f\n", h[8], h[9], h[10], h[11]);
|
||||
printf("Lane 31, col 0: w0=%.1f w1=%.1f w2=%.1f w3=%.1f\n", h[12], h[13], h[14], h[15]);
|
||||
printf("Lane 0, col 1: w0=%.1f w1=%.1f w2=%.1f w3=%.1f\n", h[16], h[17], h[18], h[19]);
|
||||
printf("Column 0, lane 0 (expect rows 0-3 = 0,10,20,30): %.1f %.1f %.1f %.1f\n",
|
||||
h[0], h[1], h[2], h[3]);
|
||||
printf("Column 0, lane 1 (expect rows 4-7 = 40,50,60,70): %.1f %.1f %.1f %.1f\n",
|
||||
h[4], h[5], h[6], h[7]);
|
||||
printf("Column 1, lane 0 (expect rows 0-3 = 1,11,21,31): %.1f %.1f %.1f %.1f\n",
|
||||
h[8], h[9], h[10], h[11]);
|
||||
printf("Column 8, lane 0 (expect row 0 col 8 = 8, row 1 col 8 = 18, etc): %.1f %.1f %.1f %.1f\n",
|
||||
h[12], h[13], h[14], h[15]);
|
||||
|
||||
// Check if warps see different rows
|
||||
int diff = (h[0] != h[1]) || (h[1] != h[2]) || (h[2] != h[3]);
|
||||
printf("\nConclusion: %s\n", diff ?
|
||||
"Warps see DIFFERENT rows — multi-warp softmax works!" :
|
||||
"Warps see SAME rows — need alternative approach");
|
||||
// Verify: column 0, lane 0 should give [0, 10, 20, 30]
|
||||
int pass = (fabsf(h[0] - 0.0f) < 0.01f) && (fabsf(h[1] - 10.0f) < 0.01f) &&
|
||||
(fabsf(h[2] - 20.0f) < 0.01f) && (fabsf(h[3] - 30.0f) < 0.01f);
|
||||
|
||||
printf("\nResult: %s\n", pass ? "16x256b.x1 LOADS work multiple times!" : "Data mismatch");
|
||||
cudaFree(d_r);
|
||||
return 0;
|
||||
return pass ? 0 : 1;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user