/**
 * Debug: TMA load Q, write canonical SMEM back to GMEM for verification.
 * This isolates the Q TMA load + canonical write pipeline.
 */
|
|
|
|
#include <cuda_runtime.h>
|
|
#include <cuda.h>
|
|
#include <cstdio>
|
|
#include <cstdlib>
|
|
#include <cstring>
|
|
|
|
#ifndef HD_VAL
|
|
#define HD_VAL 64
|
|
#endif
|
|
|
|
#include "dsv4/kernels/attention/fmha_common.cuh"
|
|
#include "dsv4/kernels/attention/fmha_umma_desc.cuh"
|
|
#include "dsv4/kernels/attention/fmha_tma.cuh"
|
|
|
|
using namespace dsv4::kernels::attention;
|
|
|
|
// Host-side float -> bf16 conversion by truncation: reinterprets the float as
// a 32-bit pattern and keeps only its upper 16 bits (no rounding is applied).
static bf16_t f32_to_bf16_host(float f) {
    uint32_t bits;
    memcpy(&bits, &f, sizeof(bits));
    const uint32_t upper_half = bits >> 16;
    return static_cast<uint16_t>(upper_half);
}
|
|
// Host-side bf16 -> float conversion: places the 16 stored bits in the upper
// half of a 32-bit pattern (lower half zero) and reinterprets it as a float.
static float bf16_to_f32_host(bf16_t h) {
    const uint32_t bits = static_cast<uint32_t>(h) << 16;
    float value;
    memcpy(&value, &bits, sizeof(value));
    return value;
}
|
|
|
|
// Number of columns of the Q tile under test. Defaults to the HD_VAL macro,
// which can be overridden at build time (e.g. -DHD_VAL=128).
constexpr int HD = HD_VAL;

// Number of K sub-tiles, each MMA_K_BF16 columns wide, needed to cover HD.
constexpr int NKT = HD / MMA_K_BF16;

// The kernel loads Q in whole K sub-tiles; a remainder would be silently
// dropped by the integer division above, so reject such configurations early.
static_assert(HD > 0, "HD_VAL must be positive");
static_assert(HD % MMA_K_BF16 == 0,
              "HD_VAL must be a multiple of MMA_K_BF16 (Q is loaded in whole K sub-tiles)");
|
|
|
|
/**
 * Debug kernel: brings the Q tile into shared memory one K sub-tile at a time
 * through the TMA helpers, scatters every sub-tile into the 8x8-core
 * "canonical" layout, and copies the result to global memory so the host can
 * verify it.
 *
 * Launch contract (as used by main() in this file): a single block of 192
 * threads. Warp 5 issues the loads and performs the scatter, so at least six
 * warps must be present. Dynamic shared memory must cover the carve-out
 * below: a reserved 4-byte slot, one 8-byte mbarrier, one
 * 128 x MMA_K_BF16 staging tile and the 128 x HD canonical tile, each slot
 * rounded up to a 128-byte boundary.
 *
 * @param out_canonical  Receives 128 * HD elements in canonical order:
 *                       idx = core_k * 16 * 64 + core_mn * 64 + local_r * 8 + local_c,
 *                       where row = core_mn * 8 + local_r, col = core_k * 8 + local_c.
 * @param out_rowmajor   Optional. When non-null, receives the same 128 * HD
 *                       elements converted back to row-major (row * HD + col).
 *                       Pass nullptr to skip this dump.
 * @param tma_q          Device pointer to the tensor map for Q; its address is
 *                       forwarded unchanged to tma_load_2d.
 */
__global__ void __launch_bounds__(192)
test_q_smem_kernel(
    bf16_t* __restrict__ out_canonical,  // (128, HD) canonical layout from SMEM
    bf16_t* __restrict__ out_rowmajor,   // (128, HD) row-major converted from canonical (optional)
    CUtensorMap* __restrict__ tma_q
) {
    static constexpr int TILE_ROWS = 128;                  // rows in every tile
    static constexpr int TILE_K = MMA_K_BF16;              // columns per staged K sub-tile
    static constexpr int CORE_DIM = 8;                     // a core is CORE_DIM x CORE_DIM elements
    static constexpr int CORE_ELEMS = CORE_DIM * CORE_DIM; // 64 elements per core
    static constexpr int CORES_MN = TILE_ROWS / CORE_DIM;  // 16 cores along the row dimension
    static constexpr int CORES_K_PER_TILE = TILE_K / CORE_DIM;
    static constexpr int TILE_SZ = TILE_ROWS * TILE_K;     // elements in the staging tile
    static constexpr int TMA_TILE_BYTES = TILE_SZ * 2;     // 2 bytes per element
    static constexpr int Q_ELEMS = TILE_ROWS * HD;

    const int tid = threadIdx.x;
    const int wid = tid / 32;
    const int lane = tid % 32;
    const bool is_load_warp = (wid == 5);

    // Shared-memory carve-out. The first 4 bytes are a reserved slot that this
    // debug kernel does not use; it is kept so every offset (and therefore the
    // host-side size computation) stays the same.
    extern __shared__ __align__(128) char sbuf[];
    size_t off = 4;
    off = (off + 127) & ~(size_t)127;
    uint64_t* sMbar = (uint64_t*)(sbuf + off); off += 8;
    off = (off + 127) & ~(size_t)127;
    bf16_t* sTmaBuf = (bf16_t*)(sbuf + off); off += TILE_SZ * sizeof(bf16_t);
    off = (off + 127) & ~(size_t)127;
    bf16_t* sQ = (bf16_t*)(sbuf + off);

    const uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbar);

    // One thread initialises the mbarrier (arrival count 1) and fences the
    // initialisation before any other thread touches the barrier.
    if (tid == 0) {
        tma_mbarrier_init(mbar_addr, 1);
        asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory");
    }
    __syncthreads();
    int phase = 0;

    // Zero the canonical buffer so any slot the scatter misses reads back as 0.
    if (is_load_warp) {
        for (int i = lane; i < Q_ELEMS; i += 32) sQ[i] = 0;
    }
    __syncthreads();

    // Stage Q one K sub-tile at a time and scatter it into canonical order.
    for (int qkt = 0; qkt < NKT; qkt++) {
        if (is_load_warp && lane == 0) {
            // Request the sub-tile starting at column qkt * TILE_K, row 0, then
            // arrive on the barrier announcing the expected byte count.
            tma_load_2d((uint32_t)__cvta_generic_to_shared(sTmaBuf), (uint64_t)tma_q,
                        mbar_addr, qkt * TILE_K, 0);
            tma_mbarrier_arrive_expect_tx(mbar_addr, TMA_TILE_BYTES);
        }
        // Every thread waits for the current barrier phase, then flips its
        // local phase bit for the next iteration.
        tma_mbarrier_wait(mbar_addr, phase); phase ^= 1;
        __syncthreads();

        // The staging tile is read as row-major TILE_ROWS x TILE_K; each
        // element is moved to its (core_k, core_mn, local_r, local_c) slot.
        if (is_load_warp) {
            for (int i = lane; i < TILE_SZ; i += 32) {
                const int r = i / TILE_K, c = i % TILE_K;
                const int core_mn = r / CORE_DIM, local_r = r % CORE_DIM;
                const int core_k_sub = c / CORE_DIM, local_c = c % CORE_DIM;
                const int core_k_full = qkt * CORES_K_PER_TILE + core_k_sub;
                const int dst_idx = core_k_full * CORES_MN * CORE_ELEMS
                                  + core_mn * CORE_ELEMS
                                  + local_r * CORE_DIM + local_c;
                sQ[dst_idx] = sTmaBuf[i];
            }
        }
        // Staging buffer must be fully consumed before the next load reuses it;
        // after the last iteration this also publishes sQ to the whole block.
        __syncthreads();
    }

    // Dump canonical SMEM to GMEM verbatim.
    for (int i = tid; i < Q_ELEMS; i += blockDim.x) {
        out_canonical[i] = sQ[i];
    }

    // Optional inverse mapping: read each (row, col) from its canonical slot
    // and store it row-major. No extra shared memory is needed because the
    // conversion goes straight from SMEM to GMEM.
    if (out_rowmajor != nullptr) {
        for (int i = tid; i < Q_ELEMS; i += blockDim.x) {
            const int r = i / HD, c = i % HD;
            const int src_idx = (c / CORE_DIM) * CORES_MN * CORE_ELEMS
                              + (r / CORE_DIM) * CORE_ELEMS
                              + (r % CORE_DIM) * CORE_DIM + (c % CORE_DIM);
            out_rowmajor[i] = sQ[src_idx];
        }
    }
}
|
|
|
|
/**
 * Host driver for test_q_smem_kernel.
 *
 * Builds a 128 x HD Q matrix whose first T rows hold pseudo-random values
 * (the rest stay zero), uploads it together with a tensor map, runs the
 * kernel in a single 192-thread block, and checks that the canonical-layout
 * dump matches the row-major input element for element.
 *
 * @return 0 when every element matches, 1 on any mismatch or on a failed
 *         allocation / CUDA call.
 */
int main() {
    printf("TMA Q SMEM Debug (HD=%d)\n", HD);
    const int T = 4;  // number of leading rows filled with non-zero data

    constexpr int ROWS = 128;
    constexpr int CORES_MN = 16;  // 128/8
    const size_t q_elems = (size_t)ROWS * HD;
    const size_t q_bytes = q_elems * sizeof(bf16_t);

    // All resources are declared up front so a single release routine can be
    // used on every exit path (free / cudaFree accept null pointers).
    bf16_t* h_q = nullptr;
    bf16_t* h_canon = nullptr;
    bf16_t* d_q = nullptr;
    bf16_t* d_out = nullptr;
    CUtensorMap* d_tma_q = nullptr;

    auto release_all = [&]() {
        cudaFree(d_tma_q);
        cudaFree(d_out);
        cudaFree(d_q);
        free(h_canon);
        free(h_q);
    };
    // Returns true on success; otherwise reports the error, releases every
    // resource acquired so far and returns false.
    auto cuda_ok = [&](cudaError_t status, const char* what) -> bool {
        if (status == cudaSuccess) return true;
        fprintf(stderr, "%s failed\n", what);
        printf("CUDA ERROR: %s\n", cudaGetErrorString(status));
        release_all();
        return false;
    };

    h_q = (bf16_t*)calloc(q_elems, sizeof(bf16_t));
    h_canon = (bf16_t*)malloc(q_bytes);
    if (h_q == nullptr || h_canon == nullptr) {
        fprintf(stderr, "host allocation failed\n");
        release_all();
        return 1;
    }

    // Fixed seed keeps the input reproducible; values lie in [-0.5, 0.49].
    srand(42);
    for (int i = 0; i < T * HD; i++) h_q[i] = f32_to_bf16_host((float)(rand()%100)/100.0f - 0.5f);

    if (!cuda_ok(cudaMalloc(&d_q, q_bytes), "cudaMalloc(d_q)")) return 1;
    if (!cuda_ok(cudaMalloc(&d_out, q_bytes), "cudaMalloc(d_out)")) return 1;
    if (!cuda_ok(cudaMemcpy(d_q, h_q, q_bytes, cudaMemcpyHostToDevice), "cudaMemcpy(d_q)")) return 1;

    // Tensor map over the whole Q matrix, fetched in 128 x 16 boxes; the
    // kernel reads it through a device-side copy.
    CUtensorMap tma_q;
    create_tma_desc_2d_bf16(&tma_q, d_q, 128, HD, 128, 16);
    if (!cuda_ok(cudaMalloc(&d_tma_q, sizeof(CUtensorMap)), "cudaMalloc(d_tma_q)")) return 1;
    if (!cuda_ok(cudaMemcpy(d_tma_q, &tma_q, sizeof(CUtensorMap), cudaMemcpyHostToDevice),
                 "cudaMemcpy(d_tma_q)")) return 1;

    // Reserved word + mbarrier + staging tile + canonical tile + alignment slack.
    int smem = 4 + 8 + 128*16*2 + 128*HD*2 + 4096;
    // Dynamic shared memory above 48 KiB needs an explicit opt-in, which
    // larger HD_VAL settings require.
    if (smem > 48 * 1024) {
        if (!cuda_ok(cudaFuncSetAttribute(test_q_smem_kernel,
                                          cudaFuncAttributeMaxDynamicSharedMemorySize, smem),
                     "cudaFuncSetAttribute(MaxDynamicSharedMemorySize)")) return 1;
    }

    test_q_smem_kernel<<<1, 192, smem>>>(d_out, nullptr, d_tma_q);
    // Launch-configuration errors surface here; in-kernel faults at the sync.
    if (!cuda_ok(cudaGetLastError(), "kernel launch")) return 1;
    if (!cuda_ok(cudaDeviceSynchronize(), "kernel execution")) return 1;

    if (!cuda_ok(cudaMemcpy(h_canon, d_out, q_bytes, cudaMemcpyDeviceToHost),
                 "cudaMemcpy(h_canon)")) return 1;

    // Verify: map every row-major (r, c) to its canonical slot and compare.
    // canonical[core_k * CORES_MN * 64 + core_mn * 64 + local_r * 8 + local_c]
    int mismatches = 0;
    int zeros = 0;
    for (int r = 0; r < ROWS; r++) {
        for (int c = 0; c < HD; c++) {
            int core_mn = r / 8, local_r = r % 8;
            int core_k = c / 8, local_c = c % 8;
            int canon_idx = core_k * CORES_MN * 64 + core_mn * 64 + local_r * 8 + local_c;
            bf16_t expected = h_q[r * HD + c];
            bf16_t got = h_canon[canon_idx];
            if (got == 0) zeros++;
            if (expected != got) mismatches++;
        }
    }
    printf("Canonical SMEM: %d mismatches, %d zeros out of %d\n", mismatches, zeros, 128 * HD);

    // Show the first few raw 16-bit patterns from both buffers.
    printf("First 10 canonical: ");
    for (int i = 0; i < 10; i++) printf("%d ", (int)h_canon[i]);
    printf("\n");
    printf("First 10 row-major Q: ");
    for (int i = 0; i < 10; i++) printf("%d ", (int)h_q[i]);
    printf("\n");

    printf("%s\n", mismatches == 0 ? "PASSED" : "FAILED");

    const int exit_code = (mismatches == 0) ? 0 : 1;
    release_all();
    return exit_code;
}
|