From d3dc8cf901965b3c17f8eb82ca4afd35f4b787a6 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Wed, 3 Jun 2026 03:46:14 +0000 Subject: [PATCH] Add prefill T=2 debug CUDA test with intermediate value printing --- tests/unit/test_prefill_t2_debug.cu | 556 ++++++++++++++++++++++++++++ 1 file changed, 556 insertions(+) create mode 100644 tests/unit/test_prefill_t2_debug.cu diff --git a/tests/unit/test_prefill_t2_debug.cu b/tests/unit/test_prefill_t2_debug.cu new file mode 100644 index 00000000..b4d8e523 --- /dev/null +++ b/tests/unit/test_prefill_t2_debug.cu @@ -0,0 +1,556 @@ +/** + * Debug test for B1 prefill kernel T>1 path. + * + * Tests T=2 N=128 step by step: + * 1. Compute QK (noPE + RoPE) for 2 query rows + * 2. Verify QK logits against CPU reference + * 3. Compute softmax + * 4. Compute PV and verify against CPU reference + * 5. Full T=2 prefill vs CPU reference + */ +#include +#include +#include +#include +#include +#include +#include + +// Include kernel headers +#include "dsv4/kernels/attention/fmha_common.cuh" +#include "dsv4/kernels/attention/fmha_umma_desc.cuh" +#include "dsv4/kernels/attention/fmha_mixed_fp8_prefill.cuh" + +using namespace dsv4::kernels::attention; + +// ---- CPU reference functions ---- + +static void cpu_fp8_e4m3_quantize(const float* src, uint8_t* dst, float* scale, + int rows, int cols) { + for (int r = 0; r < rows; r++) { + float amax = 0.0f; + for (int c = 0; c < cols; c++) amax = fmaxf(amax, fabsf(src[r * cols + c])); + float s = amax / 448.0f; + if (s < 1e-12f) s = 1.0f; + scale[r] = s; + for (int c = 0; c < cols; c++) { + float v = src[r * cols + c] / s; + v = fmaxf(-448.0f, fminf(448.0f, v)); + __nv_fp8_e4m3 fp8; fp8.__x = 0; + // Simplest quantize: round to FP8 + memcpy(&fp8, &v, 1); // This won't work, use proper conversion + dst[r * cols + c] = 0; // placeholder + } + } +} + +static float fp8_to_f32(uint8_t b) { + __nv_fp8_e4m3 v; v.__x = b; + return (float)v; +} + +static bf16_t f32_to_bf16_host(float f) { + uint32_t u; memcpy(&u, &f, 4); + uint16_t h = (u + 0x8000) >> 16; + return h; +} + +static float bf16_to_f32_host(bf16_t h) { + uint32_t u = (uint32_t)h << 16; + float f; memcpy(&f, &u, 4); + return f; +} + +// ---- Minimal T=2 kernel that prints intermediate values ---- + +__global__ void prefill_t2_debug_kernel( + const uint8_t* __restrict__ q_nope_fp8, + const float* __restrict__ q_nope_scale, + const bf16_t* __restrict__ q_rope_bf16, + const uint8_t* __restrict__ k_nope_fp8, + const float* __restrict__ k_nope_scale, + const bf16_t* __restrict__ k_rope_bf16, + int T, int N, int HD, int NOPE, int ROPE, + float scale) +{ + // Only one CTA for debug + if (blockIdx.x > 0 || blockIdx.y > 0 || blockIdx.z > 0) return; + + constexpr int SK_TILE = 128; + constexpr int MMA_K_F8 = 32; + constexpr int MMA_K_F16 = 16; + constexpr int NKT_NOPE = 448 / MMA_K_F8; // 14 + constexpr int NKT_ROPE = 64 / MMA_K_F16; // 4 + constexpr int N_SUB = 512 / 16; // 32 + constexpr int NKT_PV = SK_TILE / MMA_K_F16; // 8 + constexpr int TILE_F8 = 128 * MMA_K_F8; // 4096 + constexpr int TILE_F16 = 128 * MMA_K_F16; // 2048 + constexpr int V_SUB_SZ = 16 * MMA_K_F16; // 256 + constexpr int TMEM_COLS = 512; + constexpr int T_ACT = 2; + + const int tid = threadIdx.x; + const int wid = tid >> 5; + const int lane = tid & 31; + const bool is_mma_warp = (wid == 4); + + extern __shared__ __align__(128) char sbuf[]; + size_t off = 0; + uint32_t* sTmemBase = (uint32_t*)(sbuf + off); off += 4; + off = (off + 127) & ~(size_t)127; + uint8_t* sQ8 = (uint8_t*)(sbuf + off); off += TILE_F8; + off = (off + 127) & ~(size_t)127; + uint8_t* sK8 = (uint8_t*)(sbuf + off); off += TILE_F8; + off = (off + 127) & ~(size_t)127; + bf16_t* sQ16 = (bf16_t*)(sbuf + off); off += TILE_F16 * sizeof(bf16_t); + off = (off + 127) & ~(size_t)127; + bf16_t* sK16 = (bf16_t*)(sbuf + off); off += TILE_F16 * sizeof(bf16_t); + off = (off + 127) & ~(size_t)127; + bf16_t* sPk = (bf16_t*)(sbuf + off); off += TILE_F16 * sizeof(bf16_t); + off = (off + 127) & ~(size_t)127; + bf16_t* sV = (bf16_t*)(sbuf + off); off += V_SUB_SZ * sizeof(bf16_t); + off = (off + 127) & ~(size_t)127; + float* sLogits = (float*)(sbuf + off); off += T_ACT * SK_TILE * sizeof(float); + float* sP = (float*)(sbuf + off); off += T_ACT * SK_TILE * sizeof(float); + float* sOacc = (float*)(sbuf + off); off += T_ACT * HD * sizeof(float); + float* sRunningMax = (float*)(sbuf + off); off += T_ACT * sizeof(float); + float* sRunningSum = (float*)(sbuf + off); off += T_ACT * sizeof(float); + + // TMEM alloc + if (is_mma_warp) tmem_alloc((uint32_t)__cvta_generic_to_shared(sTmemBase), TMEM_COLS); + asm volatile("fence.proxy.async.shared::cta;" ::: "memory"); + __syncthreads(); + uint32_t tb = *sTmemBase; + + const uint32_t idesc_f8_qk = make_idesc_f8_e4m3(128, 128); + const uint32_t idesc_f16_qk = make_idesc(128, 128); + const uint32_t idesc_pv = make_idesc(128, 16); + + // Init accumulators + for (int i = tid; i < T_ACT * HD; i += blockDim.x) sOacc[i] = 0.0f; + for (int t = tid; t < T_ACT; t += blockDim.x) { + sRunningMax[t] = -INFINITY; + sRunningSum[t] = 0.0f; + } + __syncthreads(); + + // Single KV tile (N=128) + const int kv_len = min(SK_TILE, N); + + // ---- QK noPE: FP8 ---- + for (int kt = 0; kt < NKT_NOPE; kt++) { + for (int i = tid; i < TILE_F8; i += blockDim.x) { sQ8[i] = 0; sK8[i] = 0; } + __syncthreads(); + for (int r = tid; r < T_ACT; r += blockDim.x) { + for (int c = 0; c < MMA_K_F8; c++) { + int d = kt * MMA_K_F8 + c; + if (d < NOPE) sQ8[_pfill_cidx_f8(r, c)] = q_nope_fp8[r * NOPE + d]; + } + } + for (int i = tid; i < kv_len * MMA_K_F8; i += blockDim.x) { + int r = i / MMA_K_F8, c = i % MMA_K_F8; + int d = kt * MMA_K_F8 + c; + if (d < NOPE) sK8[_pfill_cidx_f8(r, c)] = k_nope_fp8[r * NOPE + d]; + } + __syncthreads(); + if (is_mma_warp && lane == 0) { + uint64_t dq = make_umma_desc_kmajor_none((uint32_t)__cvta_generic_to_shared(sQ8), 128); + uint64_t dk = make_umma_desc_kmajor_none((uint32_t)__cvta_generic_to_shared(sK8), 128); + umma_ss_f8f6f4(tb, dq, dk, idesc_f8_qk, kt > 0); + asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); + } + __syncthreads(); + } + asm volatile("fence.sc.gpu;" ::: "memory"); + __syncthreads(); + + // Read QK noPE + prefill_read_qk_rows(tb, sLogits, T_ACT, kv_len); + __syncthreads(); + + // Print QK noPE logits for rows 0,1 (first 8 values) + if (tid == 0) { + printf("QK noPE (row 0, first 8): "); + for (int c = 0; c < 8; c++) printf("%.4f ", sLogits[0 * SK_TILE + c]); + printf("\n"); + printf("QK noPE (row 1, first 8): "); + for (int c = 0; c < 8; c++) printf("%.4f ", sLogits[1 * SK_TILE + c]); + printf("\n"); + } + __syncthreads(); + + // Apply scales + for (int r = tid; r < T_ACT; r += blockDim.x) { + float q_s = q_nope_scale[r]; + for (int c = 0; c < kv_len; c++) { + sLogits[r * SK_TILE + c] *= q_s * k_nope_scale[c]; + } + } + __syncthreads(); + + if (tid == 0) { + printf("QK noPE scaled (row 0, first 8): "); + for (int c = 0; c < 8; c++) printf("%.4f ", sLogits[0 * SK_TILE + c]); + printf("\n"); + printf("QK noPE scaled (row 1, first 8): "); + for (int c = 0; c < 8; c++) printf("%.4f ", sLogits[1 * SK_TILE + c]); + printf("\n"); + } + __syncthreads(); + + // ---- QK RoPE: BF16 ---- + for (int kt = 0; kt < NKT_ROPE; kt++) { + for (int i = tid; i < TILE_F16; i += blockDim.x) { sQ16[i] = 0; sK16[i] = 0; } + __syncthreads(); + for (int r = tid; r < T_ACT; r += blockDim.x) { + for (int c = 0; c < MMA_K_F16; c++) { + int d = kt * MMA_K_F16 + c; + if (d < ROPE) sQ16[_pfill_cidx_bf16_128(r, c)] = q_rope_bf16[r * ROPE + d]; + } + } + for (int i = tid; i < kv_len * MMA_K_F16; i += blockDim.x) { + int r = i / MMA_K_F16, c = i % MMA_K_F16; + int d = kt * MMA_K_F16 + c; + if (d < ROPE) sK16[_pfill_cidx_bf16_128(r, c)] = k_rope_bf16[(int64_t)r * ROPE + d]; + } + __syncthreads(); + if (is_mma_warp && lane == 0) { + uint64_t dq = make_umma_desc_kmajor_none((uint32_t)__cvta_generic_to_shared(sQ16), 128); + uint64_t dk = make_umma_desc_kmajor_none((uint32_t)__cvta_generic_to_shared(sK16), 128); + umma_ss_f16(tb, dq, dk, idesc_f16_qk, kt > 0); + asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); + } + __syncthreads(); + } + asm volatile("fence.sc.gpu;" ::: "memory"); + __syncthreads(); + + // Add RoPE to noPE + prefill_read_qk_rows(tb, sP, T_ACT, kv_len); + __syncthreads(); + for (int i = tid; i < T_ACT * kv_len; i += blockDim.x) { + sLogits[i] += sP[i]; + } + __syncthreads(); + + if (tid == 0) { + printf("QK total (row 0, first 8): "); + for (int c = 0; c < 8; c++) printf("%.4f ", sLogits[0 * SK_TILE + c] * scale); + printf("\n"); + printf("QK total (row 1, first 8): "); + for (int c = 0; c < 8; c++) printf("%.4f ", sLogits[1 * SK_TILE + c] * scale); + printf("\n"); + } + __syncthreads(); + + // ---- Softmax ---- + for (int r = tid; r < T_ACT; r += blockDim.x) { + float tile_max = -INFINITY; + for (int c = 0; c < kv_len; c++) + tile_max = fmaxf(tile_max, sLogits[r * SK_TILE + c] * scale); + + float tile_sum = 0.0f; + for (int c = 0; c < kv_len; c++) { + float pv = expf(sLogits[r * SK_TILE + c] * scale - tile_max); + sP[r * SK_TILE + c] = pv; + tile_sum += pv; + } + for (int c = kv_len; c < SK_TILE; c++) sP[r * SK_TILE + c] = 0.0f; + + float old_max = sRunningMax[r]; + float new_max = fmaxf(old_max, tile_max); + float rescale_old = (old_max > -INFINITY) ? expf(old_max - new_max) : 0.0f; + for (int d = 0; d < HD; d++) sOacc[r * HD + d] *= rescale_old; + float rescale_new = expf(tile_max - new_max); + sRunningSum[r] = sRunningSum[r] * rescale_old + tile_sum * rescale_new; + sRunningMax[r] = new_max; + + sLogits[r * SK_TILE] = rescale_new; + } + __syncthreads(); + + if (tid == 0) { + printf("Softmax P (row 0, first 8): "); + for (int c = 0; c < 8; c++) printf("%.6f ", sP[0 * SK_TILE + c]); + printf(" sum=%.6f\n", sRunningSum[0]); + printf("Softmax P (row 1, first 8): "); + for (int c = 0; c < 8; c++) printf("%.6f ", sP[1 * SK_TILE + c]); + printf(" sum=%.6f\n", sRunningSum[1]); + printf("Rescale: row0=%.6f row1=%.6f\n", sLogits[0 * SK_TILE], sLogits[1 * SK_TILE]); + } + __syncthreads(); + + // ---- PV: per query row ---- + for (int qr = 0; qr < T_ACT; qr++) { + float p_rescale = sLogits[qr * SK_TILE]; + + if (tid == 0) printf("PV for qr=%d: p_rescale=%.6f\n", qr, p_rescale); + + for (int n_sub = 0; n_sub < N_SUB; n_sub++) { + int d_base = n_sub * 16; + for (int pv_kt = 0; pv_kt < NKT_PV; pv_kt++) { + const int col_start = pv_kt * MMA_K_F16; + for (int i = tid; i < TILE_F16; i += blockDim.x) sPk[i] = 0; + for (int i = tid; i < V_SUB_SZ; i += blockDim.x) sV[i] = 0; + __syncthreads(); + + for (int c = tid; c < MMA_K_F16; c += blockDim.x) { + int gc = col_start + c; + sPk[_pfill_cidx_bf16_128(qr, c)] = f32_to_bf16(sP[qr * SK_TILE + gc]); + } + + for (int i = tid; i < 16 * MMA_K_F16; i += blockDim.x) { + int dd = i / MMA_K_F16, kk = i % MMA_K_F16; + int row = col_start + kk; + int g_row = row; + int d = d_base + dd; + bf16_t vbits = 0; + if (row < kv_len) { + if (d < NOPE) { + uint8_t b = k_nope_fp8[(int64_t)g_row * NOPE + d]; + float v = _prefill_fp8_to_f32(b) * k_nope_scale[g_row]; + vbits = f32_to_bf16(v); + } else { + vbits = k_rope_bf16[(int64_t)g_row * ROPE + (d - NOPE)]; + } + } + sV[_pfill_cidx_bf16_16(dd, kk)] = vbits; + } + __syncthreads(); + + bool first = (pv_kt == 0); + if (is_mma_warp && lane == 0) { + uint64_t dp = make_umma_desc_kmajor_none((uint32_t)__cvta_generic_to_shared(sPk), 128); + uint64_t dv = make_umma_desc_kmajor_none((uint32_t)__cvta_generic_to_shared(sV), 16); + umma_ss_f16(tb + n_sub * 16, dp, dv, idesc_pv, !first); + asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory"); + } + __syncthreads(); + } + } + + // Read PV result for row qr + asm volatile("fence.sc.gpu;" ::: "memory"); + __syncthreads(); + prefill_read_pv_all_subs(tb, qr, sOacc, p_rescale); + __syncthreads(); + + // Print first few accumulated values + if (tid == 0 && qr == 0) { + printf("sOacc qr=0 (first 8): "); + for (int d = 0; d < 8; d++) printf("%.6f ", sOacc[0 * HD + d]); + printf("\n"); + } + if (tid == 0 && qr == 1) { + printf("sOacc qr=1 (first 8): "); + for (int d = 0; d < 8; d++) printf("%.6f ", sOacc[1 * HD + d]); + printf("\n"); + } + __syncthreads(); + } + + // Normalize and print final output + if (tid == 0) { + printf("sRunningSum: row0=%.6f row1=%.6f\n", sRunningSum[0], sRunningSum[1]); + printf("sRunningMax: row0=%.6f row1=%.6f\n", sRunningMax[0], sRunningMax[1]); + printf("Final output row0 (first 8): "); + for (int d = 0; d < 8; d++) printf("%.6f ", sOacc[0 * HD + d] / sRunningSum[0]); + printf("\n"); + printf("Final output row1 (first 8): "); + for (int d = 0; d < 8; d++) printf("%.6f ", sOacc[1 * HD + d] / sRunningSum[1]); + printf("\n"); + + // Check for NaN + bool has_nan0 = false, has_nan1 = false; + for (int d = 0; d < HD; d++) { + if (isnan(sOacc[0 * HD + d])) has_nan0 = true; + if (isnan(sOacc[1 * HD + d])) has_nan1 = true; + } + printf("NaN check: row0=%s row1=%s\n", has_nan0 ? "YES" : "no", has_nan1 ? "YES" : "no"); + } + + if (is_mma_warp) tmem_dealloc(tb, TMEM_COLS); +} + +int main() { + constexpr int T = 2; + constexpr int N = 128; + constexpr int HD = 512; + constexpr int NOPE = 448; + constexpr int ROPE = 64; + constexpr float scale = 1.0f / sqrtf((float)HD); + + printf("=== Prefill T=2 Debug Test ===\n"); + printf("T=%d N=%d HD=%d NOPE=%d ROPE=%d scale=%.6f\n", T, N, HD, NOPE, ROPE, scale); + + // Generate random data on CPU, then upload + srand(42); + + // Q: (T, HD) FP32 → quantize noPE to FP8, keep RoPE as BF16 + float* h_q = (float*)malloc(T * HD * sizeof(float)); + for (int i = 0; i < T * HD; i++) h_q[i] = (float)rand() / RAND_MAX * 0.5f - 0.25f; + + // K: (N, HD) FP32 → quantize noPE to FP8, keep RoPE as BF16 + float* h_k = (float*)malloc(N * HD * sizeof(float)); + for (int i = 0; i < N * HD; i++) h_k[i] = (float)rand() / RAND_MAX * 0.5f - 0.25f; + + // Q noPE FP8 quantization (per-row scale) + uint8_t* h_q_nope_fp8 = (uint8_t*)malloc(T * NOPE); + float* h_q_nope_scale = (float*)malloc(T * sizeof(float)); + for (int r = 0; r < T; r++) { + float amax = 0.0f; + for (int c = 0; c < NOPE; c++) amax = fmaxf(amax, fabsf(h_q[r * HD + c])); + float s = amax / 448.0f; + if (s < 1e-12f) s = 1.0f; + h_q_nope_scale[r] = s; + for (int c = 0; c < NOPE; c++) { + float v = h_q[r * HD + c] / s; + v = fmaxf(-448.0f, fminf(448.0f, v)); + __nv_fp8_e4m3 fp8 = __nv_fp8_e4m3(v); + h_q_nope_fp8[r * NOPE + c] = fp8.__x; + } + } + + // Q RoPE BF16 + bf16_t* h_q_rope_bf16 = (bf16_t*)malloc(T * ROPE * sizeof(bf16_t)); + for (int r = 0; r < T; r++) + for (int c = 0; c < ROPE; c++) + h_q_rope_bf16[r * ROPE + c] = f32_to_bf16_host(h_q[r * HD + NOPE + c]); + + // K noPE FP8 quantization + uint8_t* h_k_nope_fp8 = (uint8_t*)malloc(N * NOPE); + float* h_k_nope_scale = (float*)malloc(N * sizeof(float)); + for (int r = 0; r < N; r++) { + float amax = 0.0f; + for (int c = 0; c < NOPE; c++) amax = fmaxf(amax, fabsf(h_k[r * HD + c])); + float s = amax / 448.0f; + if (s < 1e-12f) s = 1.0f; + h_k_nope_scale[r] = s; + for (int c = 0; c < NOPE; c++) { + float v = h_k[r * HD + c] / s; + v = fmaxf(-448.0f, fminf(448.0f, v)); + __nv_fp8_e4m3 fp8 = __nv_fp8_e4m3(v); + h_k_nope_fp8[r * NOPE + c] = fp8.__x; + } + } + + // K RoPE BF16 + bf16_t* h_k_rope_bf16 = (bf16_t*)malloc(N * ROPE * sizeof(bf16_t)); + for (int r = 0; r < N; r++) + for (int c = 0; c < ROPE; c++) + h_k_rope_bf16[r * ROPE + c] = f32_to_bf16_host(h_k[r * HD + NOPE + c]); + + // Upload to GPU + uint8_t *d_q_nope_fp8, *d_k_nope_fp8; + float *d_q_nope_scale, *d_k_nope_scale; + bf16_t *d_q_rope_bf16, *d_k_rope_bf16; + + cudaMalloc(&d_q_nope_fp8, T * NOPE); + cudaMalloc(&d_q_nope_scale, T * sizeof(float)); + cudaMalloc(&d_q_rope_bf16, T * ROPE * sizeof(bf16_t)); + cudaMalloc(&d_k_nope_fp8, N * NOPE); + cudaMalloc(&d_k_nope_scale, N * sizeof(float)); + cudaMalloc(&d_k_rope_bf16, N * ROPE * sizeof(bf16_t)); + + cudaMemcpy(d_q_nope_fp8, h_q_nope_fp8, T * NOPE, cudaMemcpyHostToDevice); + cudaMemcpy(d_q_nope_scale, h_q_nope_scale, T * sizeof(float), cudaMemcpyHostToDevice); + cudaMemcpy(d_q_rope_bf16, h_q_rope_bf16, T * ROPE * sizeof(bf16_t), cudaMemcpyHostToDevice); + cudaMemcpy(d_k_nope_fp8, h_k_nope_fp8, N * NOPE, cudaMemcpyHostToDevice); + cudaMemcpy(d_k_nope_scale, h_k_nope_scale, N * sizeof(float), cudaMemcpyHostToDevice); + cudaMemcpy(d_k_rope_bf16, h_k_rope_bf16, N * ROPE * sizeof(bf16_t), cudaMemcpyHostToDevice); + + // Compute CPU reference QK + printf("\n=== CPU Reference QK ===\n"); + float ref_qk[2][128] = {}; + for (int r = 0; r < T; r++) { + for (int c = 0; c < N; c++) { + float dot = 0.0f; + // noPE: FP8 dequant dot product + for (int d = 0; d < NOPE; d++) { + float qv = fp8_to_f32(h_q_nope_fp8[r * NOPE + d]) * h_q_nope_scale[r]; + float kv = fp8_to_f32(h_k_nope_fp8[c * NOPE + d]) * h_k_nope_scale[c]; + dot += qv * kv; + } + // RoPE: BF16 dot product + for (int d = 0; d < ROPE; d++) { + float qv = bf16_to_f32_host(h_q_rope_bf16[r * ROPE + d]); + float kv = bf16_to_f32_host(h_k_rope_bf16[c * ROPE + d]); + dot += qv * kv; + } + ref_qk[r][c] = dot * scale; + } + } + printf("CPU ref QK (row 0, first 8): "); + for (int c = 0; c < 8; c++) printf("%.4f ", ref_qk[0][c]); + printf("\n"); + printf("CPU ref QK (row 1, first 8): "); + for (int c = 0; c < 8; c++) printf("%.4f ", ref_qk[1][c]); + printf("\n"); + + // Compute CPU reference softmax + printf("\n=== CPU Reference Softmax + Attention ===\n"); + float ref_softmax[2][128] = {}; + for (int r = 0; r < T; r++) { + float mx = ref_qk[r][0]; + for (int c = 1; c < N; c++) mx = fmaxf(mx, ref_qk[r][c]); + float sm = 0.0f; + for (int c = 0; c < N; c++) { + ref_softmax[r][c] = expf(ref_qk[r][c] - mx); + sm += ref_softmax[r][c]; + } + for (int c = 0; c < N; c++) ref_softmax[r][c] /= sm; + } + printf("CPU ref softmax (row 0, first 8): "); + for (int c = 0; c < 8; c++) printf("%.6f ", ref_softmax[0][c]); + printf("\n"); + + // Compute CPU reference attention output + float ref_out[2][512] = {}; + for (int r = 0; r < T; r++) { + for (int d = 0; d < HD; d++) { + float val = 0.0f; + for (int c = 0; c < N; c++) { + float kv; + if (d < NOPE) { + kv = fp8_to_f32(h_k_nope_fp8[c * NOPE + d]) * h_k_nope_scale[c]; + } else { + kv = bf16_to_f32_host(h_k_rope_bf16[c * ROPE + (d - NOPE)]); + } + val += ref_softmax[r][c] * kv; + } + ref_out[r][d] = val; + } + } + printf("CPU ref output (row 0, first 8): "); + for (int d = 0; d < 8; d++) printf("%.6f ", ref_out[0][d]); + printf("\n"); + printf("CPU ref output (row 1, first 8): "); + for (int d = 0; d < 8; d++) printf("%.6f ", ref_out[1][d]); + printf("\n"); + + // Launch debug kernel + printf("\n=== GPU Kernel Execution ===\n"); + int smem_size = 256 * 1024; // generous + prefill_t2_debug_kernel<<>>( + d_q_nope_fp8, d_q_nope_scale, d_q_rope_bf16, + d_k_nope_fp8, d_k_nope_scale, d_k_rope_bf16, + T, N, HD, NOPE, ROPE, scale); + cudaDeviceSynchronize(); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("Kernel launch FAILED: %s\n", cudaGetErrorString(err)); + } else { + printf("Kernel completed successfully.\n"); + } + + // Cleanup + cudaFree(d_q_nope_fp8); cudaFree(d_q_nope_scale); cudaFree(d_q_rope_bf16); + cudaFree(d_k_nope_fp8); cudaFree(d_k_nope_scale); cudaFree(d_k_rope_bf16); + free(h_q); free(h_k); + free(h_q_nope_fp8); free(h_q_nope_scale); free(h_q_rope_bf16); + free(h_k_nope_fp8); free(h_k_nope_scale); free(h_k_rope_bf16); + + printf("\n=== Done ===\n"); + return 0; +}