ROOT CAUSE of TMEM hang: `tcgen05.fence.cta_group::1.sync.aligned` is NOT a valid PTX instruction. The correct primitives for waiting on TMEM accesses are `tcgen05.wait::st.sync.aligned` (wait for prior TMEM stores to complete) and `tcgen05.wait::ld.sync.aligned` (wait for prior TMEM loads to complete). Both were found in cutlass/arch/barrier.h, in fence_view_async_tmem_store / fence_view_async_tmem_load.
124 lines
4.1 KiB
Plaintext
124 lines
4.1 KiB
Plaintext
/**
 * Minimal TMEM test — isolate the hang.
 * Test: alloc, store, load, dealloc.
 * Run on B200 via: nvcc -std=c++20 -gencode=arch=compute_100a,code=sm_100a test_tmem_minimal.cu -o test_tmem -lcudart
 */
|
|
#include <cuda_runtime.h>
|
|
#include <cstdint>
|
|
#include <cstdio>
|
|
#include <cstring>
|
|
|
|
/**
 * Allocates tensor-memory (TMEM) columns with `tcgen05.alloc`.
 *
 * The instruction is `.sync.aligned`, so it MUST be executed by a FULLY ACTIVE
 * WARP (all 32 lanes). It WRITES the base TMEM address of the new allocation
 * to the shared-memory word addressed by `smem_ptr`; threads outside the
 * calling warp must synchronize (e.g. `__syncthreads()`) before reading it.
 *
 * @param smem_ptr  32-bit shared-memory address of the word that receives the
 *                  TMEM base address (callers in this file obtain it with
 *                  `__cvta_generic_to_shared`).
 * @param num_cols  Number of TMEM columns to allocate: 32-512, power of 2.
 */
__device__ void tmem_alloc(uint32_t smem_ptr, int num_cols) {
    // The "memory" clobber tells the compiler that this asm stores to shared
    // memory, so it must not cache or reorder accesses to that word across it.
    asm volatile(
        "tcgen05.alloc.cta_group::1.sync.aligned.shared::cta.b32 [%0], %1;"
        :
        : "r"(smem_ptr), "r"(num_cols)
        : "memory");
}
|
|
|
|
/**
 * Releases a TMEM allocation with `tcgen05.dealloc`.
 *
 * The instruction is `.sync.aligned`; callers in this file invoke it from all
 * lanes of warp 0.
 *
 * @param tmem_ptr  TMEM base address to release (the value `tcgen05.alloc`
 *                  wrote to shared memory).
 * @param num_cols  Number of columns to release; callers in this file pass
 *                  the same count they allocated.
 */
__device__ void tmem_dealloc(uint32_t tmem_ptr, int num_cols) {
    asm volatile(
        "tcgen05.dealloc.cta_group::1.sync.aligned.b32 %0, %1;"
        :
        : "r"(tmem_ptr), "r"(num_cols));
}
|
|
|
|
/**
 * Issues `tcgen05.wait::st`, which waits for previously issued TMEM stores
 * (`tcgen05.st`) to complete.
 *
 * The instruction is `.sync.aligned`. The "memory" clobber additionally keeps
 * the compiler from moving memory accesses across the wait.
 */
__device__ void tmem_fence_store() {
    asm volatile("tcgen05.wait::st.sync.aligned;"
                 :
                 :
                 : "memory");
}
|
|
|
|
/**
 * Issues `tcgen05.wait::ld`, which waits for previously issued TMEM loads
 * (`tcgen05.ld`) to complete.
 *
 * The instruction is `.sync.aligned`. The "memory" clobber additionally keeps
 * the compiler from moving memory accesses across the wait.
 */
__device__ void tmem_fence_load() {
    asm volatile("tcgen05.wait::ld.sync.aligned;"
                 :
                 :
                 : "memory");
}
|
|
|
|
/**
 * Stores four 32-bit registers per thread into TMEM using
 * `tcgen05.st.sync.aligned.16x256b.x1.b32`.
 *
 * This only issues the store; use `tmem_fence_store()` to wait for it to
 * complete. The instruction is `.sync.aligned`; the caller in this file
 * invokes it from all lanes of warp 0.
 *
 * @param col  TMEM address operand. Despite the name, the caller in this file
 *             passes `tmem_base + column_index`, i.e. a full TMEM address.
 * @param r0   First 32-bit value contributed by this thread.
 * @param r1   Second 32-bit value contributed by this thread.
 * @param r2   Third 32-bit value contributed by this thread.
 * @param r3   Fourth 32-bit value contributed by this thread.
 */
__device__ void tmem_store(uint32_t col, uint32_t r0, uint32_t r1, uint32_t r2, uint32_t r3) {
    asm volatile(
        "tcgen05.st.sync.aligned.16x256b.x1.b32 [%0], {%1, %2, %3, %4};"
        :
        : "r"(col), "r"(r0), "r"(r1), "r"(r2), "r"(r3));
}
|
|
|
|
/**
 * Loads four 32-bit registers per thread from TMEM using
 * `tcgen05.ld.sync.aligned.16x256b.x1.b32`.
 *
 * This only issues the load; use `tmem_fence_load()` to wait for it to
 * complete before consuming the outputs. The instruction is `.sync.aligned`;
 * the caller in this file invokes it from all lanes of warp 0.
 *
 * @param col  TMEM address operand. Despite the name, the caller in this file
 *             passes `tmem_base + column_index`, i.e. a full TMEM address.
 * @param r0   Receives the first 32-bit value for this thread.
 * @param r1   Receives the second 32-bit value for this thread.
 * @param r2   Receives the third 32-bit value for this thread.
 * @param r3   Receives the fourth 32-bit value for this thread.
 */
__device__ void tmem_load(uint32_t col, uint32_t &r0, uint32_t &r1, uint32_t &r2, uint32_t &r3) {
    asm volatile(
        "tcgen05.ld.sync.aligned.16x256b.x1.b32 {%0, %1, %2, %3}, [%4];"
        : "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3)
        : "r"(col));
}
|
|
|
|
/**
 * Test 1: allocate 32 TMEM columns, print the base address, deallocate.
 *
 * Launch expectations (as used by main in this file): one block of 64
 * threads, with dynamic shared memory whose first 4 bytes receive the TMEM
 * base address written by the allocation.
 */
__global__ void test_tmem_alloc_only() {
    extern __shared__ char sbuf[];
    uint32_t* tmem_base_ptr = reinterpret_cast<uint32_t*>(sbuf);

    const bool in_first_warp = (threadIdx.x / 32) == 0;
    const bool is_first_thread = (threadIdx.x == 0);

    // Only warp 0 allocates; all of its lanes execute the call together.
    if (in_first_warp) {
        const uint32_t smem_addr = __cvta_generic_to_shared(tmem_base_ptr);
        tmem_alloc(smem_addr, 32);
    }
    // Make the base address written to shared memory visible to the block.
    __syncthreads();

    const uint32_t tmem_base = *tmem_base_ptr;
    if (is_first_thread) {
        printf("alloc_only: tmem_base = %u\n", tmem_base);
    }

    // The same warp that allocated releases the columns again.
    if (in_first_warp) {
        tmem_dealloc(tmem_base, 32);
    }
    if (is_first_thread) {
        printf("alloc_only: PASS (no hang)\n");
    }
}
|
|
|
|
/**
 * Test 2: allocate 32 TMEM columns, store four words per thread, read them
 * back, verify the round trip, and deallocate.
 *
 * Launch expectations (as used by main in this file): one block of 64
 * threads, with dynamic shared memory whose first 4 bytes receive the TMEM
 * base address written by the allocation.
 *
 * All TMEM instructions are issued by warp 0 with all 32 lanes active,
 * because the `.sync.aligned` instructions must be executed by the whole warp.
 */
__global__ void test_tmem_store_load() {
    extern __shared__ char sbuf[];
    uint32_t* tmem_base_ptr = (uint32_t*)sbuf;

    int wid = threadIdx.x / 32;
    if (wid == 0) {
        uint32_t smem_ptr = __cvta_generic_to_shared(tmem_base_ptr);
        tmem_alloc(smem_ptr, 32);
    }
    // Make the base address written to shared memory visible to the block.
    __syncthreads();

    uint32_t tmem_base = *tmem_base_ptr;
    if (threadIdx.x == 0) printf("store_load: tmem_base = %u\n", tmem_base);

    // Bit patterns of the four floats that every lane of warp 0 stores.
    const float f_in0 = 1.0f, f_in1 = 2.0f, f_in2 = 3.0f, f_in3 = 4.0f;
    uint32_t w0, w1, w2, w3;
    memcpy(&w0, &f_in0, 4); memcpy(&w1, &f_in1, 4);
    memcpy(&w2, &f_in2, 4); memcpy(&w3, &f_in3, 4);

    // TMEM address used for both the store and the load: base + column 0.
    const uint32_t col_addr = tmem_base + 0;

    // tcgen05.st is warp-collective: every lane of warp 0 executes it and
    // contributes its own four registers (here all lanes pass the same values).
    if (wid == 0) {
        tmem_store(col_addr, w0, w1, w2, w3);
        // The store is asynchronous; wait until it has completed.
        tmem_fence_store();
    }
    __syncthreads();

    // Read back with the same shape and address, so each lane should get
    // exactly the four words it stored.
    bool ok = true;
    if (wid == 0) {
        uint32_t u0 = 0, u1 = 0, u2 = 0, u3 = 0;
        tmem_load(col_addr, u0, u1, u2, u3);
        // FIX: the load is asynchronous as well. Its destination registers
        // must not be consumed, and the columns must not be deallocated,
        // before tcgen05.wait::ld has completed.
        tmem_fence_load();

        float f0, f1, f2, f3;
        memcpy(&f0, &u0, 4); memcpy(&f1, &u1, 4);
        memcpy(&f2, &u2, 4); memcpy(&f3, &u3, 4);
        if (threadIdx.x == 0) {
            printf("store_load: read %f %f %f %f\n", f0, f1, f2, f3);
            // Compare bit patterns so the check is exact.
            ok = (u0 == w0) && (u1 == w1) && (u2 == w2) && (u3 == w3);
        }
    }
    __syncthreads();

    if (wid == 0) tmem_dealloc(tmem_base, 32);
    if (threadIdx.x == 0) {
        // Report PASS only when the read-back matched what was stored.
        if (ok) {
            printf("store_load: PASS\n");
        } else {
            printf("store_load: FAIL (read-back does not match stored values)\n");
        }
    }
}
|
|
|
|
/**
 * Host driver: runs the two TMEM test kernels in sequence, each with one
 * block of 64 threads, and stops at the first CUDA error.
 *
 * @return 0 when both kernels launched and ran without a CUDA error, 1 otherwise.
 */
int main() {
    printf("=== TMET Minimal Test ===\n\n");

    printf("Test 1: Alloc + Dealloc only...\n");
    test_tmem_alloc_only<<<1, 64, 256>>>();
    // Launch failures (bad configuration, no kernel image for this device)
    // are only reported through cudaGetLastError(), so check it before
    // waiting for the kernel to finish.
    cudaError_t err = cudaGetLastError();
    if (err == cudaSuccess) err = cudaDeviceSynchronize();
    if (err != cudaSuccess) { printf("FAIL: %s\n", cudaGetErrorString(err)); return 1; }

    printf("\nTest 2: Store + Load...\n");
    test_tmem_store_load<<<1, 64, 1024>>>();
    err = cudaGetLastError();
    if (err == cudaSuccess) err = cudaDeviceSynchronize();
    if (err != cudaSuccess) { printf("FAIL: %s\n", cudaGetErrorString(err)); return 1; }

    printf("\nAll tests PASSED!\n");
    return 0;
}
|