test: TMEM 4 columns, individual store calls + loop load
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
/**
|
||||
* TMEM column addressing test — is tmem_base + 1 a valid column?
|
||||
* TMEM column addressing test — simplified, matches minimal test pattern
|
||||
*/
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
@@ -7,107 +7,77 @@
|
||||
#include <cstdint>
|
||||
#include <cstring>
|
||||
|
||||
typedef unsigned short bf16_t;
|
||||
constexpr int WARP = 32;
|
||||
|
||||
__device__ void tmem_alloc(uint32_t smem_ptr, int num_cols) {
|
||||
asm volatile("tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [%0], %1;"
|
||||
:: "r"(smem_ptr), "r"(num_cols));
|
||||
__device__ void tmem_alloc(uint32_t sp, int n) {
|
||||
asm volatile("tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [%0], %1;" :: "r"(sp), "r"(n));
|
||||
}
|
||||
__device__ void tmem_dealloc(uint32_t tmem_ptr, int num_cols) {
|
||||
asm volatile("tcgen05.dealloc.cta_group::1.sync.aligned.b32 %0, %1;"
|
||||
:: "r"(tmem_ptr), "r"(num_cols));
|
||||
__device__ void tmem_dealloc(uint32_t tp, int n) {
|
||||
asm volatile("tcgen05.dealloc.cta_group::1.sync.aligned.b32 %0, %1;" :: "r"(tp), "r"(n));
|
||||
}
|
||||
__device__ void tmem_store(uint32_t col_addr, uint32_t r0, uint32_t r1, uint32_t r2, uint32_t r3) {
|
||||
asm volatile("tcgen05.st.sync.aligned.16x256b.x1.b32 [%0], {%1, %2, %3, %4};"
|
||||
:: "r"(col_addr), "r"(r0), "r"(r1), "r"(r2), "r"(r3));
|
||||
__device__ void tmem_store(uint32_t c, uint32_t r0, uint32_t r1, uint32_t r2, uint32_t r3) {
|
||||
asm volatile("tcgen05.st.sync.aligned.16x256b.x1.b32 [%0], {%1, %2, %3, %4};" :: "r"(c), "r"(r0), "r"(r1), "r"(r2), "r"(r3));
|
||||
}
|
||||
__device__ void tmem_load(uint32_t col_addr, uint32_t& r0, uint32_t& r1, uint32_t& r2, uint32_t& r3) {
|
||||
asm volatile("tcgen05.ld.sync.aligned.16x256b.x1.b32 {%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3) : "r"(col_addr));
|
||||
}
|
||||
__device__ void tmem_fence_store() {
|
||||
asm volatile("tcgen05.wait::st.sync.aligned;" ::: "memory");
|
||||
__device__ void tmem_load(uint32_t c, uint32_t &r0, uint32_t &r1, uint32_t &r2, uint32_t &r3) {
|
||||
asm volatile("tcgen05.ld.sync.aligned.16x256b.x1.b32 {%0, %1, %2, %3}, [%4];" : "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3) : "r"(c));
|
||||
}
|
||||
|
||||
__global__ void test_tmem_cols(float* out) {
|
||||
int tid = threadIdx.x;
|
||||
int lane = tid % WARP;
|
||||
|
||||
__global__ void test_tmem_loop(float* out) {
|
||||
extern __shared__ char sbuf[];
|
||||
uint32_t* sBase = (uint32_t*)sbuf;
|
||||
int lane = threadIdx.x % WARP;
|
||||
|
||||
// Alloc 128 TMEM columns
|
||||
if (tid < 32) {
|
||||
uint32_t sp = __cvta_generic_to_shared(sBase);
|
||||
tmem_alloc(sp, 128);
|
||||
// Alloc 32 TMEM columns (same as minimal test)
|
||||
if (threadIdx.x < 32) {
|
||||
tmem_alloc(__cvta_generic_to_shared(sBase), 32);
|
||||
}
|
||||
__syncthreads();
|
||||
uint32_t tmem_base = *sBase;
|
||||
uint32_t tb = *sBase;
|
||||
|
||||
// Store a unique value to each column
|
||||
// Column i gets value i*1000 + lane*4 + 0..3
|
||||
if (tid < 32) {
|
||||
for (int col = 0; col < 128; col++) {
|
||||
float v0 = (float)(col * 1000 + lane * 4 + 0);
|
||||
float v1 = (float)(col * 1000 + lane * 4 + 1);
|
||||
float v2 = (float)(col * 1000 + lane * 4 + 2);
|
||||
float v3 = (float)(col * 1000 + lane * 4 + 3);
|
||||
uint32_t u0, u1, u2, u3;
|
||||
memcpy(&u0, &v0, 4); memcpy(&u1, &v1, 4);
|
||||
memcpy(&u2, &v2, 4); memcpy(&u3, &v3, 4);
|
||||
tmem_store(tmem_base + col, u0, u1, u2, u3);
|
||||
}
|
||||
tmem_fence_store();
|
||||
// Store to columns 0..3 individually (no loop)
|
||||
if (threadIdx.x < 32) {
|
||||
float v0 = (float)(lane * 4 + 0);
|
||||
float v1 = (float)(lane * 4 + 1);
|
||||
float v2 = (float)(lane * 4 + 2);
|
||||
float v3 = (float)(lane * 4 + 3);
|
||||
uint32_t u0, u1, u2, u3;
|
||||
memcpy(&u0, &v0, 4); memcpy(&u1, &v1, 4);
|
||||
memcpy(&u2, &v2, 4); memcpy(&u3, &v3, 4);
|
||||
|
||||
tmem_store(tb + 0, u0, u1, u2, u3);
|
||||
tmem_store(tb + 1, u0, u1, u2, u3);
|
||||
tmem_store(tb + 2, u0, u1, u2, u3);
|
||||
tmem_store(tb + 3, u0, u1, u2, u3);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Read back first 8 columns (lane 0 only)
|
||||
if (tid < 32) {
|
||||
for (int col = 0; col < 8; col++) {
|
||||
// Read back
|
||||
if (threadIdx.x < 32) {
|
||||
for (int c = 0; c < 4; c++) {
|
||||
uint32_t u0, u1, u2, u3;
|
||||
tmem_load(tmem_base + col, u0, u1, u2, u3);
|
||||
if (lane == 0) {
|
||||
float v0; memcpy(&v0, &u0, 4);
|
||||
out[col] = v0;
|
||||
}
|
||||
tmem_load(tb + c, u0, u1, u2, u3);
|
||||
float v0; memcpy(&v0, &u0, 4);
|
||||
if (lane == 0) out[c] = v0;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Dealloc
|
||||
if (tid < 32) {
|
||||
tmem_dealloc(tmem_base, 128);
|
||||
}
|
||||
if (threadIdx.x < 32) tmem_dealloc(tb, 32);
|
||||
}
|
||||
|
||||
int main() {
|
||||
printf("=== TMEM Column Addressing Test ===\n");
|
||||
float* h_out = (float*)calloc(8, sizeof(float));
|
||||
float* d_out;
|
||||
cudaMalloc(&d_out, 8 * sizeof(float));
|
||||
cudaMemset(d_out, 0, 8 * sizeof(float));
|
||||
printf("=== TMEM Loop Test ===\n");
|
||||
float* h_out = (float*)calloc(4, sizeof(float));
|
||||
float* d_out; cudaMalloc(&d_out, 4 * sizeof(float));
|
||||
cudaMemset(d_out, 0, 4 * sizeof(float));
|
||||
|
||||
test_tmem_cols<<<1, 32, 256>>>(d_out);
|
||||
test_tmem_loop<<<1, 32, 256>>>(d_out);
|
||||
cudaError_t err = cudaDeviceSynchronize();
|
||||
if (err != cudaSuccess) {
|
||||
printf("CUDA ERROR: %s\n", cudaGetErrorString(err));
|
||||
return 1;
|
||||
}
|
||||
if (err != cudaSuccess) { printf("CUDA ERROR: %s\n", cudaGetErrorString(err)); return 1; }
|
||||
|
||||
cudaMemcpy(h_out, d_out, 8 * sizeof(float), cudaMemcpyDeviceToHost);
|
||||
printf("Column values (expected 0, 1000, 2000, ...):\n");
|
||||
for (int i = 0; i < 8; i++) {
|
||||
printf(" col %d: %.1f (expected %d.0)\n", i, h_out[i], i * 1000);
|
||||
}
|
||||
|
||||
int ok = 1;
|
||||
for (int i = 0; i < 8; i++) {
|
||||
if (fabsf(h_out[i] - i * 1000.0f) > 0.1f) ok = 0;
|
||||
}
|
||||
printf("Test %s\n", ok ? "PASSED" : "FAILED");
|
||||
|
||||
cudaFree(d_out);
|
||||
free(h_out);
|
||||
return ok ? 0 : 1;
|
||||
cudaMemcpy(h_out, d_out, 4 * sizeof(float), cudaMemcpyDeviceToHost);
|
||||
for (int i = 0; i < 4; i++) printf("col %d: %.1f\n", i, h_out[i]);
|
||||
printf("Test %s\n", h_out[0] == 0.0f ? "PASSED" : "CHECK");
|
||||
cudaFree(d_out); free(h_out);
|
||||
return 0;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user