// File: nvfp4-megamoe-kernel/tests/unit/test_q_smem_debug.cu
// (Repository viewer metadata: 169 lines, 6.2 KiB, rendered as plaintext.)

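/**
 * Debug test: load Q through TMA, repack it into the canonical shared-memory
 * layout (8x8-element cores), and copy that shared-memory image back to global
 * memory so the host can verify it against the original row-major Q.
 * This isolates the Q TMA load + canonical write pipeline.
 */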
#include <cuda_runtime.h>
#include <cuda.h>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#ifndef HD_VAL
#define HD_VAL 64
#endif
#include "dsv4/kernels/attention/fmha_common.cuh"
#include "dsv4/kernels/attention/fmha_umma_desc.cuh"
#include "dsv4/kernels/attention/fmha_tma.cuh"
using namespace dsv4::kernels::attention;
/**
 * Host-side float -> bf16 conversion by truncation: keeps the upper 16 bits of
 * the IEEE-754 single-precision bit pattern and discards the lower 16 (no rounding).
 */
static bf16_t f32_to_bf16_host(float f) {
    uint32_t bits;
    memcpy(&bits, &f, sizeof(bits));  // type-pun through memcpy, not a pointer cast
    const uint16_t upper_half = (uint16_t)(bits >> 16);
    return upper_half;
}
/**
 * Host-side bf16 -> float conversion: places the 16 stored bits in the upper
 * half of a 32-bit word (lower half zero) and reinterprets it as a float.
 */
static float bf16_to_f32_host(bf16_t h) {
    const uint32_t bits = (uint32_t)h << 16;
    float value;
    memcpy(&value, &bits, sizeof(value));  // type-pun through memcpy, not a pointer cast
    return value;
}
// Number of columns in the 128 x HD Q tile, taken from the compile-time HD_VAL
// setting (presumably the attention head dimension -- TODO confirm).
constexpr int HD = HD_VAL;
// Number of MMA_K_BF16-wide column sub-tiles that make up HD columns; the
// kernel issues one TMA load per sub-tile.
constexpr int NKT = HD / MMA_K_BF16;
/**
 * Debug kernel: loads a 128 x HD bf16 Q tile through TMA, one
 * 128 x MMA_K_BF16 sub-tile at a time, repacks every sub-tile into the
 * canonical shared-memory layout, and dumps that canonical image to global
 * memory unchanged.
 *
 * Canonical layout (cores of 8 x 8 = 64 elements):
 *   idx = core_k * CORES_MN * 64 + core_mn * 64 + local_r * 8 + local_c
 *   where r = core_mn * 8 + local_r and c = core_k * 8 + local_c.
 * The inverse (canonical -> row-major) mapping is not performed on the device;
 * the host side of this test applies it when comparing against the input.
 *
 * Launch requirements:
 *   - blockDim.x == 192 (warp 5 issues the TMA loads and does the repack; the
 *     final dump strides by 192).
 *   - Dynamic shared memory covering the carve-up below:
 *     4352 bytes of header + staging buffer (when MMA_K_BF16 == 16 and
 *     sizeof(bf16_t) == 2), plus 128 * HD * sizeof(bf16_t).
 *
 * @param out_canonical  Receives 128 * HD elements copied verbatim from the
 *                       canonical shared-memory buffer.
 * @param out_rowmajor   Unused by this kernel; kept for interface
 *                       compatibility. May be nullptr.
 * @param tma_q          Tensor map passed to tma_load_2d for the Q loads.
 */
__global__ void __launch_bounds__(192)
test_q_smem_kernel(
    bf16_t* __restrict__ out_canonical, // (128, HD) canonical layout from SMEM
    bf16_t* __restrict__ out_rowmajor,  // unused (see header comment)
    CUtensorMap* __restrict__ tma_q
) {
    static constexpr int BLOCK_THREADS    = 192;
    static constexpr int TILE_ROWS        = 128;
    static constexpr int TILE_K           = MMA_K_BF16;           // columns per TMA sub-tile
    static constexpr int CORE_DIM         = 8;                    // a core is 8 x 8 elements
    static constexpr int CORE_ELEMS       = CORE_DIM * CORE_DIM;  // 64
    static constexpr int CORES_MN         = TILE_ROWS / CORE_DIM; // 16 cores along the rows
    static constexpr int CORES_K_PER_TILE = TILE_K / CORE_DIM;    // cores along K per sub-tile
    static constexpr int TILE_SZ          = TILE_ROWS * TILE_K;   // elements per sub-tile
    static constexpr int TMA_TILE_BYTES   = TILE_ROWS * MMA_K_BF16 * 2;

    static_assert(TILE_K % CORE_DIM == 0, "TMA sub-tile width must be a whole number of 8-wide cores");
    static_assert(HD % MMA_K_BF16 == 0, "HD must be a multiple of MMA_K_BF16 or tail columns are never loaded");

    (void)out_rowmajor;  // intentionally unused

    const int tid  = threadIdx.x;
    const int wid  = tid / 32;
    const int lane = tid % 32;
    const bool is_load_warp = (wid == 5);

    // Shared-memory carve-up. Every region starts on a 128-byte boundary.
    // A 4-byte slot at offset 0 is reserved but not used by this test; it is
    // kept so the offsets of the regions below stay unchanged.
    extern __shared__ __align__(128) char sbuf[];
    size_t off = 4;
    off = (off + 127) & ~(size_t)127;
    uint64_t* sMbar = (uint64_t*)(sbuf + off); off += 8;
    off = (off + 127) & ~(size_t)127;
    bf16_t* sTmaBuf = (bf16_t*)(sbuf + off); off += TILE_SZ * sizeof(bf16_t);  // TMA staging tile
    off = (off + 127) & ~(size_t)127;
    bf16_t* sQ = (bf16_t*)(sbuf + off);                                        // canonical Q, 128 * HD elements

    // One thread initialises the mbarrier (arrival count 1) and fences the
    // initialisation before the rest of the block uses the barrier.
    if (tid == 0) {
        tma_mbarrier_init((uint32_t)__cvta_generic_to_shared(sMbar), 1);
        asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory");
    }
    __syncthreads();

    const uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbar);
    int phase = 0;

    // Zero the canonical buffer so positions that are never written read back as 0.
    if (is_load_warp) {
        for (int i = lane; i < TILE_ROWS * HD; i += 32) sQ[i] = 0;
    }
    __syncthreads();

    // Load Q one sub-tile at a time and scatter each sub-tile into canonical order.
    for (int qkt = 0; qkt < NKT; qkt++) {
        if (is_load_warp && lane == 0) {
            // Coordinates (qkt * MMA_K_BF16, 0): presumably (column, row) of the
            // sub-tile origin -- confirm against tma_load_2d.
            tma_load_2d((uint32_t)__cvta_generic_to_shared(sTmaBuf), (uint64_t)tma_q, mbar_addr, qkt * MMA_K_BF16, 0);
            tma_mbarrier_arrive_expect_tx(mbar_addr, TMA_TILE_BYTES);
        }
        // Every thread waits for the load; the phase value flips after each wait.
        tma_mbarrier_wait(mbar_addr, phase); phase ^= 1;
        __syncthreads();

        // The staging buffer is read as row-major (TILE_ROWS x TILE_K).
        if (is_load_warp) {
            for (int i = lane; i < TILE_SZ; i += 32) {
                const int r = i / TILE_K;
                const int c = i % TILE_K;
                const int core_mn     = r / CORE_DIM;
                const int local_r     = r % CORE_DIM;
                const int core_k_sub  = c / CORE_DIM;
                const int local_c     = c % CORE_DIM;
                const int core_k_full = qkt * CORES_K_PER_TILE + core_k_sub;
                const int dst_idx = core_k_full * CORES_MN * CORE_ELEMS
                                  + core_mn * CORE_ELEMS
                                  + local_r * CORE_DIM
                                  + local_c;
                sQ[dst_idx] = sTmaBuf[i];
            }
        }
        // The staging buffer must be fully consumed before the next TMA load
        // overwrites it, and sQ must be complete before the final dump.
        __syncthreads();
    }

    // Dump the canonical SMEM image to GMEM using the whole block.
    for (int i = tid; i < TILE_ROWS * HD; i += BLOCK_THREADS) {
        out_canonical[i] = sQ[i];
    }
}
/** Reports a failed CUDA runtime call on stdout; returns true when `err` is cudaSuccess. */
static bool q_smem_debug_cuda_ok(cudaError_t err, const char* what) {
    if (err == cudaSuccess) return true;
    printf("CUDA ERROR (%s): %s\n", what, cudaGetErrorString(err));
    return false;
}

/**
 * Host driver for the Q TMA / canonical-SMEM debug test.
 *
 * Fills the first T rows of a 128 x HD row-major Q with pseudo-random bf16
 * values (remaining rows stay zero), runs test_q_smem_kernel on it, then maps
 * every (row, column) to its canonical index and compares the dumped value
 * bit-for-bit with the input.
 *
 * @return 0 when every element matches, 1 on any mismatch or CUDA/host error.
 */
int main() {
    printf("TMA Q SMEM Debug (HD=%d)\n", HD);
    const int T = 4;
    const size_t q_elems = (size_t)128 * HD;
    const size_t q_bytes = q_elems * sizeof(bf16_t);

    // Host buffers: zero-initialised input and the canonical image read back from the GPU.
    bf16_t* h_q = (bf16_t*)calloc(q_elems, sizeof(bf16_t));
    bf16_t* h_canon = (bf16_t*)malloc(q_bytes);
    if (h_q == nullptr || h_canon == nullptr) {
        printf("Host allocation failed\n");
        free(h_q);
        free(h_canon);
        return 1;
    }

    srand(42);  // fixed seed so runs are reproducible
    for (int i = 0; i < T * HD; i++) h_q[i] = f32_to_bf16_host((float)(rand()%100)/100.0f - 0.5f);

    bf16_t* d_q = nullptr;
    bf16_t* d_out = nullptr;
    CUtensorMap* d_tma_q = nullptr;

    bool ok = q_smem_debug_cuda_ok(cudaMalloc(&d_q, q_bytes), "cudaMalloc d_q")
           && q_smem_debug_cuda_ok(cudaMalloc(&d_out, q_bytes), "cudaMalloc d_out")
           && q_smem_debug_cuda_ok(cudaMalloc(&d_tma_q, sizeof(CUtensorMap)), "cudaMalloc d_tma_q")
           && q_smem_debug_cuda_ok(cudaMemcpy(d_q, h_q, q_bytes, cudaMemcpyHostToDevice), "cudaMemcpy h_q -> d_q");

    if (ok) {
        // Tensor map over the 128 x HD matrix with a 128 x 16 box; it is built
        // on the host and copied to device memory for the kernel to reference.
        CUtensorMap tma_q;
        create_tma_desc_2d_bf16(&tma_q, d_q, 128, HD, 128, 16);
        ok = q_smem_debug_cuda_ok(cudaMemcpy(d_tma_q, &tma_q, sizeof(CUtensorMap), cudaMemcpyHostToDevice),
                                  "cudaMemcpy tensor map");
    }

    // Header words + one staging tile + canonical Q + slack for alignment padding.
    const int smem = 4 + 8 + 128*16*2 + 128*HD*2 + 4096;
    if (ok && smem > 48 * 1024) {
        // More than 48 KiB of dynamic shared memory needs an explicit opt-in.
        ok = q_smem_debug_cuda_ok(
            cudaFuncSetAttribute(test_q_smem_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem),
            "cudaFuncSetAttribute(MaxDynamicSharedMemorySize)");
    }

    if (ok) {
        test_q_smem_kernel<<<1, 192, smem>>>(d_out, nullptr, d_tma_q);
        // Launch-configuration errors are reported by cudaGetLastError();
        // faults during execution surface at the synchronize.
        ok = q_smem_debug_cuda_ok(cudaGetLastError(), "kernel launch")
          && q_smem_debug_cuda_ok(cudaDeviceSynchronize(), "kernel execution");
    }
    if (ok) {
        ok = q_smem_debug_cuda_ok(cudaMemcpy(h_canon, d_out, q_bytes, cudaMemcpyDeviceToHost),
                                  "cudaMemcpy d_out -> h_canon");
    }

    int mismatches = 0;
    if (ok) {
        // Verify: for every row-major (r, c) compute its canonical index
        // canonical[core_k * CORES_MN * 64 + core_mn * 64 + local_r * 8 + local_c]
        // and compare raw bf16 bits with the input.
        constexpr int CORES_MN = 16; // 128/8
        int zeros = 0;
        for (int r = 0; r < 128; r++) {
            for (int c = 0; c < HD; c++) {
                int core_mn = r / 8, local_r = r % 8;
                int core_k = c / 8, local_c = c % 8;
                int canon_idx = core_k * CORES_MN * 64 + core_mn * 64 + local_r * 8 + local_c;
                bf16_t expected = h_q[r * HD + c];
                bf16_t got = h_canon[canon_idx];
                if (got == 0) zeros++;
                if (expected != got) mismatches++;
            }
        }
        printf("Canonical SMEM: %d mismatches, %d zeros out of %d\n", mismatches, zeros, 128 * HD);
        // Show first few values of each buffer as raw bf16 bit patterns
        printf("First 10 canonical: ");
        for (int i = 0; i < 10; i++) printf("%d ", (int)h_canon[i]);
        printf("\n");
        printf("First 10 row-major Q: ");
        for (int i = 0; i < 10; i++) printf("%d ", (int)h_q[i]);
        printf("\n");
    }

    const bool passed = ok && mismatches == 0;
    printf("%s\n", passed ? "PASSED" : "FAILED");

    // Release everything on all paths; cudaFree/free accept null pointers.
    cudaFree(d_tma_q);
    cudaFree(d_out);
    cudaFree(d_q);
    free(h_canon);
    free(h_q);
    return passed ? 0 : 1;
}