/**
 * Debug: TMA-load Q, scatter it into the canonical SMEM layout, and write the
 * SMEM contents back to GMEM for verification.
 * This isolates the Q TMA load + canonical write pipeline.
 *
 * Canonical layout used by this test (and checked on the host): Q is split
 * into 8x8 bf16 "core matrices"; element (r, c) lives at
 *   (c / 8) * CORES_MN * 64 + (r / 8) * 64 + (r % 8) * 8 + (c % 8)
 * so each core matrix is a contiguous 64-element block, and core matrices are
 * ordered MN-fastest within each column of cores.
 *
 * Launch: one block of NUM_THREADS (192) threads; warp LOAD_WARP issues TMA.
 * Dynamic SMEM: SMEM_BYTES (main() opts in beyond the 48 KB default).
 * Build with -DHD_VAL=<head dim>; HD must be a multiple of MMA_K_BF16.
 */
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cuda.h>
#include <cuda_runtime.h>

#ifndef HD_VAL
#define HD_VAL 64
#endif

#include "dsv4/kernels/attention/fmha_common.cuh"
#include "dsv4/kernels/attention/fmha_umma_desc.cuh"
#include "dsv4/kernels/attention/fmha_tma.cuh"

using namespace dsv4::kernels::attention;

// Abort the test with a readable message if a CUDA runtime call fails.
// (Deliberately not named CUDA_CHECK to avoid clashing with project headers.)
#define DBG_CUDA_CHECK(call)                                                 \
    do {                                                                     \
        cudaError_t dbg_err_ = (call);                                       \
        if (dbg_err_ != cudaSuccess) {                                       \
            printf("CUDA ERROR: %s (%s at %s:%d)\n",                         \
                   cudaGetErrorString(dbg_err_), #call, __FILE__, __LINE__); \
            std::exit(1);                                                    \
        }                                                                    \
    } while (0)

// float -> bf16 by truncating the low 16 mantissa bits (round toward zero).
static bf16_t f32_to_bf16_host(float f) {
    uint32_t u;
    memcpy(&u, &f, 4);
    return (uint16_t)(u >> 16);
}

// bf16 -> float (exact: bf16 is the upper half of an IEEE-754 float).
static float bf16_to_f32_host(bf16_t h) {
    uint32_t u = (uint32_t)h << 16;
    float f;
    memcpy(&f, &u, 4);
    return f;
}

constexpr int HD = HD_VAL;
constexpr int Q_ROWS = 128;                              // rows of Q staged in SMEM
constexpr int NKT = HD / MMA_K_BF16;                     // number of TMA column slabs
constexpr int CORE_DIM = 8;                              // core matrix is 8x8 elements
constexpr int CORE_ELEMS = CORE_DIM * CORE_DIM;
constexpr int CORES_MN = Q_ROWS / CORE_DIM;              // core matrices along rows
constexpr int TMA_TILE_ELEMS = Q_ROWS * MMA_K_BF16;      // one TMA box: Q_ROWS x MMA_K_BF16
constexpr int NUM_THREADS = 192;
constexpr int LOAD_WARP = 5;

static_assert(HD % MMA_K_BF16 == 0, "HD must be a multiple of MMA_K_BF16");
static_assert(MMA_K_BF16 % CORE_DIM == 0,
              "TMA box inner extent must be a multiple of 8 bf16 (16 bytes)");
static_assert(Q_ROWS % CORE_DIM == 0, "Q_ROWS must be a multiple of 8");
static_assert(NUM_THREADS > LOAD_WARP * 32, "load warp must exist in the block");

// Dynamic SMEM layout, shared by the kernel (offsets) and the host (size).
constexpr size_t align_up(size_t x, size_t a) { return (x + a - 1) / a * a; }
constexpr size_t SMEM_TMEM_OFF = 0;  // 4-byte slot reserved for a TMEM base address (unused here)
constexpr size_t SMEM_MBAR_OFF = align_up(SMEM_TMEM_OFF + sizeof(uint32_t), 128);
constexpr size_t SMEM_TMA_OFF  = align_up(SMEM_MBAR_OFF + sizeof(uint64_t), 128);
constexpr size_t SMEM_Q_OFF    = align_up(SMEM_TMA_OFF + (size_t)TMA_TILE_ELEMS * sizeof(bf16_t), 128);
constexpr size_t SMEM_BYTES    = SMEM_Q_OFF + (size_t)Q_ROWS * HD * sizeof(bf16_t);

// Index of row-major element (r, c) of the Q_ROWS x HD tile in the canonical layout.
__host__ __device__ inline int canonical_index(int r, int c) {
    return (c / CORE_DIM) * CORES_MN * CORE_ELEMS
         + (r / CORE_DIM) * CORE_ELEMS
         + (r % CORE_DIM) * CORE_DIM
         + (c % CORE_DIM);
}

/**
 * Loads a Q_ROWS x HD bf16 tile of Q through TMA in NKT column slabs of width
 * MMA_K_BF16. It rearranges each slab into the canonical layout in SMEM, then
 * dumps the result to global memory.
 *
 * @param out_canonical Q_ROWS*HD elements: raw canonical SMEM image (skipped if null).
 * @param out_rowmajor  Q_ROWS*HD elements: canonical image converted back to
 *                      row-major (skipped if null), for a round-trip check.
 * @param tma_q         Tensor map for Q in device memory; its box is assumed to be
 *                      Q_ROWS x MMA_K_BF16 (as built by main()).
 */
__global__ void __launch_bounds__(NUM_THREADS) test_q_smem_kernel(
    bf16_t* __restrict__ out_canonical,
    bf16_t* __restrict__ out_rowmajor,
    CUtensorMap* __restrict__ tma_q)
{
    static constexpr int TMA_TILE_BYTES = TMA_TILE_ELEMS * (int)sizeof(bf16_t);

    const int tid = threadIdx.x;
    const int lane = tid % 32;
    const bool is_load_warp = (tid / 32) == LOAD_WARP;

    extern __shared__ __align__(128) char sbuf[];
    uint64_t* sMbar = reinterpret_cast<uint64_t*>(sbuf + SMEM_MBAR_OFF);
    bf16_t* sTmaBuf = reinterpret_cast<bf16_t*>(sbuf + SMEM_TMA_OFF);  // TMA staging box
    bf16_t* sQ      = reinterpret_cast<bf16_t*>(sbuf + SMEM_Q_OFF);    // canonical Q

    const uint32_t mbar_addr = (uint32_t)__cvta_generic_to_shared(sMbar);
    if (tid == 0) {
        tma_mbarrier_init(mbar_addr, 1);
        asm volatile("fence.mbarrier_init.release.cluster;" ::: "memory");
    }
    // Pre-zero the canonical buffer so any slot the scatter misses shows up as 0.
    for (int i = tid; i < Q_ROWS * HD; i += blockDim.x) sQ[i] = 0;
    __syncthreads();

    int phase = 0;
    for (int qkt = 0; qkt < NKT; ++qkt) {
        if (is_load_warp && lane == 0) {
            // Arm the barrier with the expected byte count, then issue the copy.
            tma_mbarrier_arrive_expect_tx(mbar_addr, TMA_TILE_BYTES);
            tma_load_2d((uint32_t)__cvta_generic_to_shared(sTmaBuf),
                        reinterpret_cast<uint64_t>(tma_q), mbar_addr,
                        qkt * MMA_K_BF16, 0);
        }
        // Each thread waits on the mbarrier itself before touching sTmaBuf.
        tma_mbarrier_wait(mbar_addr, phase);
        phase ^= 1;

        if (is_load_warp) {
            // Staging box is row-major Q_ROWS x MMA_K_BF16; scatter into canonical slots.
            for (int i = lane; i < TMA_TILE_ELEMS; i += 32) {
                const int r = i / MMA_K_BF16;
                const int c = qkt * MMA_K_BF16 + i % MMA_K_BF16;
                sQ[canonical_index(r, c)] = sTmaBuf[i];
            }
        }
        // sTmaBuf is overwritten by the next TMA: all reads must finish first.
        __syncthreads();
    }

    if (out_canonical != nullptr) {
        for (int i = tid; i < Q_ROWS * HD; i += blockDim.x) out_canonical[i] = sQ[i];
    }
    if (out_rowmajor != nullptr) {
        // Inverse of the scatter: gather each row-major element from its canonical slot.
        for (int i = tid; i < Q_ROWS * HD; i += blockDim.x) {
            out_rowmajor[i] = sQ[canonical_index(i / HD, i % HD)];
        }
    }
}

int main() {
    printf("TMA Q SMEM Debug (HD=%d)\n", HD);
    const int T = 4;  // only the first T rows of Q receive random data; the rest stay 0
    const size_t q_elems = (size_t)Q_ROWS * HD;
    const size_t q_bytes = q_elems * sizeof(bf16_t);

    bf16_t* h_q = (bf16_t*)calloc(q_elems, sizeof(bf16_t));
    bf16_t* h_canon = (bf16_t*)malloc(q_bytes);
    bf16_t* h_rowmajor = (bf16_t*)malloc(q_bytes);
    if (h_q == nullptr || h_canon == nullptr || h_rowmajor == nullptr) {
        printf("host allocation failed\n");
        return 1;
    }
    srand(42);
    for (int i = 0; i < T * HD; i++)
        h_q[i] = f32_to_bf16_host((float)(rand() % 100) / 100.0f - 0.5f);

    bf16_t *d_q = nullptr, *d_out = nullptr, *d_rowmajor = nullptr;
    CUtensorMap* d_tma_q = nullptr;
    DBG_CUDA_CHECK(cudaMalloc(&d_q, q_bytes));
    DBG_CUDA_CHECK(cudaMalloc(&d_out, q_bytes));
    DBG_CUDA_CHECK(cudaMalloc(&d_rowmajor, q_bytes));
    DBG_CUDA_CHECK(cudaMemcpy(d_q, h_q, q_bytes, cudaMemcpyHostToDevice));

    CUtensorMap tma_q;
    create_tma_desc_2d_bf16(&tma_q, d_q, Q_ROWS, HD, Q_ROWS, MMA_K_BF16);
    DBG_CUDA_CHECK(cudaMalloc(&d_tma_q, sizeof(CUtensorMap)));
    DBG_CUDA_CHECK(cudaMemcpy(d_tma_q, &tma_q, sizeof(CUtensorMap), cudaMemcpyHostToDevice));

    // Size comes from the same layout constants the kernel uses; opt in so that
    // large HD values (> ~48 KB of SMEM) can still launch.
    const int smem = (int)SMEM_BYTES;
    DBG_CUDA_CHECK(cudaFuncSetAttribute(test_q_smem_kernel,
                                        cudaFuncAttributeMaxDynamicSharedMemorySize, smem));
    test_q_smem_kernel<<<1, NUM_THREADS, smem>>>(d_out, d_rowmajor, d_tma_q);
    DBG_CUDA_CHECK(cudaGetLastError());       // launch-configuration errors
    DBG_CUDA_CHECK(cudaDeviceSynchronize());  // execution errors

    DBG_CUDA_CHECK(cudaMemcpy(h_canon, d_out, q_bytes, cudaMemcpyDeviceToHost));
    DBG_CUDA_CHECK(cudaMemcpy(h_rowmajor, d_rowmajor, q_bytes, cudaMemcpyDeviceToHost));

    // Verify the canonical image element-by-element, and the row-major round trip.
    int mismatches = 0;
    int zeros = 0;
    int rowmajor_mismatches = 0;
    for (int r = 0; r < Q_ROWS; r++) {
        for (int c = 0; c < HD; c++) {
            const int idx = r * HD + c;
            const bf16_t expected = h_q[idx];
            const bf16_t got = h_canon[canonical_index(r, c)];
            if (got == 0) zeros++;
            if (expected != got) mismatches++;
            if (h_rowmajor[idx] != expected) rowmajor_mismatches++;
        }
    }
    printf("Canonical SMEM: %d mismatches, %d zeros out of %d\n",
           mismatches, zeros, Q_ROWS * HD);
    printf("Row-major round trip: %d mismatches out of %d\n",
           rowmajor_mismatches, Q_ROWS * HD);

    printf("First 10 canonical: ");
    for (int i = 0; i < 10; i++) printf("%.3f ", bf16_to_f32_host(h_canon[i]));
    printf("\n");
    printf("First 10 row-major Q: ");
    for (int i = 0; i < 10; i++) printf("%.3f ", bf16_to_f32_host(h_q[i]));
    printf("\n");

    const bool passed = (mismatches == 0) && (rowmajor_mismatches == 0);
    printf("%s\n", passed ? "PASSED" : "FAILED");

    cudaFree(d_tma_q);
    cudaFree(d_rowmajor);
    cudaFree(d_out);
    cudaFree(d_q);
    free(h_rowmajor);
    free(h_canon);
    free(h_q);
    return passed ? 0 : 1;
}