Files
nvfp4-megamoe-kernel/tests/unit/test_prefill_t2_debug.cu

558 lines
21 KiB
Plaintext

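/**
* Debug test for the B1 prefill kernel's T>1 path.
*
* Runs T=2, N=128 step by step:
* 1. Compute QK (noPE + RoPE) for 2 query rows
* 2. Verify QK logits against CPU reference
* 3. Compute softmax
* 4. Compute PV and verify against CPU reference
* 5. Full T=2 prefill vs CPU reference
*
* Note: the GPU intermediates and the CPU reference values are both printed
* to stdout; the comparison is made by reading the output, no tolerance
* check is asserted automatically.
*/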
#include <cuda_runtime.h>
#include <cuda_fp8.h>
#include <cuda_fp8.hpp>
#include <cstdio>
#include <cmath>
#include <cstdlib>
#include <cassert>
// Include kernel headers
#include "dsv4/kernels/attention/fmha_common.cuh"
#include "dsv4/kernels/attention/fmha_umma_desc.cuh"
#include "dsv4/kernels/attention/fmha_mixed_fp8_prefill.cuh"
using namespace dsv4::kernels::attention;
// ---- CPU reference functions ----
/**
 * Host-side row-wise FP8 E4M3 quantization.
 *
 * For every row the scale is amax(|row|) / 448 (448 is the clamp bound used
 * below); rows whose scale would fall under 1e-12 use a scale of 1 so that
 * the division stays finite. Each element is divided by the row scale,
 * clamped to [-448, 448] and converted with the __nv_fp8_e4m3 float
 * constructor.
 *
 * @param src   rows * cols input values, row-major, row stride == cols
 * @param dst   rows * cols output bytes (raw FP8 storage), same layout as src
 * @param scale rows output scales, one per row
 * @param rows  number of rows
 * @param cols  number of columns per row
 */
static void cpu_fp8_e4m3_quantize(const float* src, uint8_t* dst, float* scale,
                                  int rows, int cols) {
    for (int r = 0; r < rows; r++) {
        const float* row_src = src + (size_t)r * cols;
        uint8_t* row_dst = dst + (size_t)r * cols;

        float amax = 0.0f;
        for (int c = 0; c < cols; c++) amax = fmaxf(amax, fabsf(row_src[c]));

        float s = amax / 448.0f;
        if (s < 1e-12f) s = 1.0f;  // all-zero (or denormal) row: avoid dividing by ~0
        scale[r] = s;

        for (int c = 0; c < cols; c++) {
            float v = row_src[c] / s;
            v = fmaxf(-448.0f, fminf(448.0f, v));
            // Real float -> FP8 conversion; the previous code stored a
            // placeholder 0 for every element.
            row_dst[c] = __nv_fp8_e4m3(v).__x;
        }
    }
}
// Decode a single raw FP8 E4M3 storage byte into a float on the host.
static float fp8_to_f32(uint8_t b) {
    __nv_fp8_e4m3 encoded;
    encoded.__x = b;  // install the raw bit pattern, no numeric conversion
    const float decoded = static_cast<float>(encoded);
    return decoded;
}
// Host float -> bf16 bit pattern: adds 0x8000 to the raw 32-bit
// representation and keeps the upper 16 bits.
static bf16_t f32_to_bf16_host(float f) {
    uint32_t raw_bits;
    memcpy(&raw_bits, &f, sizeof(raw_bits));
    const uint32_t biased = raw_bits + 0x8000;
    const uint16_t upper_half = biased >> 16;
    return upper_half;
}
// Host bf16 bit pattern -> float: the 16 stored bits become the upper half
// of the 32-bit float representation, the lower half is zero.
static float bf16_to_f32_host(bf16_t h) {
    const uint32_t widened = static_cast<uint32_t>(h) << 16;
    float result;
    memcpy(&result, &widened, sizeof(result));
    return result;
}
// ---- Minimal T=2 kernel that prints intermediate values ----
/**
 * Single-CTA debug kernel for the T=2 prefill path. It runs QK (FP8 noPE part
 * plus BF16 RoPE part), a softmax, and the PV product for two query rows and
 * prints the first 8 values of every intermediate from thread 0.
 *
 * Launch requirements (checked at entry, the kernel prints and returns early
 * if they are not met):
 *   - only block (0,0,0) does any work; other blocks return immediately
 *   - blockDim.x >= 160, because warp 4 is the warp that allocates TMEM and
 *     issues the MMAs (main launches 192 threads)
 *   - T >= 2, N >= 1, HD == 512, NOPE == 448, ROPE == 64: the tiling constants
 *     below are hard-coded for exactly this shape
 *   - dynamic shared memory large enough for the carve-up below
 * Only the first min(N, 128) key rows are used (one KV tile).
 *
 * @param q_nope_fp8   query noPE part, raw FP8 bytes, indexed [row * NOPE + d]
 * @param q_nope_scale per-query-row dequant scale, indexed [row]
 * @param q_rope_bf16  query RoPE part, bf16 bits, indexed [row * ROPE + d]
 * @param k_nope_fp8   key noPE part, raw FP8 bytes, indexed [row * NOPE + d]
 * @param k_nope_scale per-key-row dequant scale, indexed [row]
 * @param k_rope_bf16  key RoPE part, bf16 bits, indexed [row * ROPE + d]
 * @param T,N,HD,NOPE,ROPE problem shape (see requirements above)
 * @param scale        multiplier applied to the summed logits before softmax
 *
 * NOTE(review): tmem_alloc/tmem_dealloc, umma_ss_*, make_idesc*,
 * make_umma_desc_kmajor_none, _pfill_cidx_*, prefill_read_qk_rows,
 * prefill_read_pv_all_subs, f32_to_bf16 and _prefill_fp8_to_f32 come from the
 * included fmha headers; their exact semantics are not visible here.
 */
__global__ void prefill_t2_debug_kernel(
    const uint8_t* __restrict__ q_nope_fp8,
    const float* __restrict__ q_nope_scale,
    const bf16_t* __restrict__ q_rope_bf16,
    const uint8_t* __restrict__ k_nope_fp8,
    const float* __restrict__ k_nope_scale,
    const bf16_t* __restrict__ k_rope_bf16,
    int T, int N, int HD, int NOPE, int ROPE,
    float scale)
{
    // Only one CTA for debug
    if (blockIdx.x > 0 || blockIdx.y > 0 || blockIdx.z > 0) return;

    constexpr int SK_TILE = 128;
    constexpr int MMA_K_F8 = 32;
    constexpr int MMA_K_F16 = 16;
    constexpr int NKT_NOPE = 448 / MMA_K_F8;    // 14
    constexpr int NKT_ROPE = 64 / MMA_K_F16;    // 4
    constexpr int N_SUB = 512 / 16;             // 32
    constexpr int NKT_PV = SK_TILE / MMA_K_F16; // 8
    constexpr int TILE_F8 = 128 * MMA_K_F8;     // 4096
    constexpr int TILE_F16 = 128 * MMA_K_F16;   // 2048
    constexpr int V_SUB_SZ = 16 * MMA_K_F16;    // 256
    constexpr int TMEM_COLS = 512;
    constexpr int T_ACT = 2;
    constexpr int MMA_WARP = 4;

    // Reject shapes the hard-coded tiling cannot handle. Without this, T < 2
    // reads past the Q buffers, NOPE + ROPE < 512 reads past k_rope_bf16 in
    // the PV stage, and a block without warp 4 never allocates TMEM. The
    // condition is uniform across the block, so returning here (before any
    // barrier or TMEM allocation) is safe.
    const bool config_ok = (T >= T_ACT) && (N >= 1) &&
                           (HD == N_SUB * 16) &&
                           (NOPE == NKT_NOPE * MMA_K_F8) &&
                           (ROPE == NKT_ROPE * MMA_K_F16) &&
                           (blockDim.x >= (MMA_WARP + 1) * 32);
    if (!config_ok) {
        if (threadIdx.x == 0) {
            printf("prefill_t2_debug_kernel: unsupported config T=%d N=%d HD=%d NOPE=%d ROPE=%d blockDim.x=%d\n",
                   T, N, HD, NOPE, ROPE, (int)blockDim.x);
        }
        return;
    }

    const int tid = threadIdx.x;
    const int wid = tid >> 5;
    const int lane = tid & 31;
    const bool is_mma_warp = (wid == MMA_WARP);

    // ---- Dynamic shared-memory carve-up (operand tiles are 128B aligned) ----
    extern __shared__ __align__(128) char sbuf[];
    size_t off = 0;
    uint32_t* sTmemBase = (uint32_t*)(sbuf + off); off += 4;
    off = (off + 127) & ~(size_t)127;
    uint8_t* sQ8 = (uint8_t*)(sbuf + off); off += TILE_F8;
    off = (off + 127) & ~(size_t)127;
    uint8_t* sK8 = (uint8_t*)(sbuf + off); off += TILE_F8;
    off = (off + 127) & ~(size_t)127;
    bf16_t* sQ16 = (bf16_t*)(sbuf + off); off += TILE_F16 * sizeof(bf16_t);
    off = (off + 127) & ~(size_t)127;
    bf16_t* sK16 = (bf16_t*)(sbuf + off); off += TILE_F16 * sizeof(bf16_t);
    off = (off + 127) & ~(size_t)127;
    bf16_t* sPk = (bf16_t*)(sbuf + off); off += TILE_F16 * sizeof(bf16_t);
    off = (off + 127) & ~(size_t)127;
    bf16_t* sV = (bf16_t*)(sbuf + off); off += V_SUB_SZ * sizeof(bf16_t);
    off = (off + 127) & ~(size_t)127;
    float* sLogits = (float*)(sbuf + off); off += T_ACT * SK_TILE * sizeof(float);
    float* sP = (float*)(sbuf + off); off += T_ACT * SK_TILE * sizeof(float);
    float* sOacc = (float*)(sbuf + off); off += T_ACT * HD * sizeof(float);
    float* sRunningMax = (float*)(sbuf + off); off += T_ACT * sizeof(float);
    float* sRunningSum = (float*)(sbuf + off); off += T_ACT * sizeof(float);

    // TMEM alloc: warp 4 allocates, the base address is handed to every
    // thread through shared memory.
    if (is_mma_warp) tmem_alloc((uint32_t)__cvta_generic_to_shared(sTmemBase), TMEM_COLS);
    asm volatile("fence.proxy.async.shared::cta;" ::: "memory");
    __syncthreads();
    uint32_t tb = *sTmemBase;

    const uint32_t idesc_f8_qk = make_idesc_f8_e4m3(128, 128);
    const uint32_t idesc_f16_qk = make_idesc(128, 128);
    const uint32_t idesc_pv = make_idesc(128, 16);

    // Init accumulators
    for (int i = tid; i < T_ACT * HD; i += blockDim.x) sOacc[i] = 0.0f;
    for (int t = tid; t < T_ACT; t += blockDim.x) {
        sRunningMax[t] = -INFINITY;
        sRunningSum[t] = 0.0f;
    }
    __syncthreads();

    // Single KV tile (N=128)
    const int kv_len = min(SK_TILE, N);

    // ---- QK noPE: FP8 ----
    // One MMA per 32-wide K slice; slices after the first accumulate.
    for (int kt = 0; kt < NKT_NOPE; kt++) {
        // Zero both operand tiles so unused rows/columns contribute nothing.
        for (int i = tid; i < TILE_F8; i += blockDim.x) { sQ8[i] = 0; sK8[i] = 0; }
        __syncthreads();
        for (int r = tid; r < T_ACT; r += blockDim.x) {
            for (int c = 0; c < MMA_K_F8; c++) {
                int d = kt * MMA_K_F8 + c;
                if (d < NOPE) sQ8[_pfill_cidx_f8(r, c)] = q_nope_fp8[r * NOPE + d];
            }
        }
        for (int i = tid; i < kv_len * MMA_K_F8; i += blockDim.x) {
            int r = i / MMA_K_F8, c = i % MMA_K_F8;
            int d = kt * MMA_K_F8 + c;
            if (d < NOPE) sK8[_pfill_cidx_f8(r, c)] = k_nope_fp8[r * NOPE + d];
        }
        __syncthreads();
        if (is_mma_warp && lane == 0) {
            uint64_t dq = make_umma_desc_kmajor_none((uint32_t)__cvta_generic_to_shared(sQ8), 128);
            uint64_t dk = make_umma_desc_kmajor_none((uint32_t)__cvta_generic_to_shared(sK8), 128);
            umma_ss_f8f6f4(tb, dq, dk, idesc_f8_qk, kt > 0);
            asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
        }
        __syncthreads();
    }
    asm volatile("fence.sc.gpu;" ::: "memory");
    __syncthreads();

    // Read QK noPE
    prefill_read_qk_rows<SK_TILE>(tb, sLogits, T_ACT, kv_len);
    __syncthreads();

    // Print QK noPE logits for rows 0,1 (first 8 values)
    if (tid == 0) {
        printf("QK noPE (row 0, first 8): ");
        for (int c = 0; c < 8; c++) printf("%.4f ", sLogits[0 * SK_TILE + c]);
        printf("\n");
        printf("QK noPE (row 1, first 8): ");
        for (int c = 0; c < 8; c++) printf("%.4f ", sLogits[1 * SK_TILE + c]);
        printf("\n");
    }
    __syncthreads();

    // Apply scales: logit(r, c) *= q_scale[r] * k_scale[c]
    for (int r = tid; r < T_ACT; r += blockDim.x) {
        float q_s = q_nope_scale[r];
        for (int c = 0; c < kv_len; c++) {
            sLogits[r * SK_TILE + c] *= q_s * k_nope_scale[c];
        }
    }
    __syncthreads();
    if (tid == 0) {
        printf("QK noPE scaled (row 0, first 8): ");
        for (int c = 0; c < 8; c++) printf("%.4f ", sLogits[0 * SK_TILE + c]);
        printf("\n");
        printf("QK noPE scaled (row 1, first 8): ");
        for (int c = 0; c < 8; c++) printf("%.4f ", sLogits[1 * SK_TILE + c]);
        printf("\n");
    }
    __syncthreads();

    // ---- QK RoPE: BF16 ----
    // Same scheme with 16-wide K slices; the first slice (kt == 0) does not
    // accumulate, so the noPE result left in TMEM is not added in again.
    for (int kt = 0; kt < NKT_ROPE; kt++) {
        for (int i = tid; i < TILE_F16; i += blockDim.x) { sQ16[i] = 0; sK16[i] = 0; }
        __syncthreads();
        for (int r = tid; r < T_ACT; r += blockDim.x) {
            for (int c = 0; c < MMA_K_F16; c++) {
                int d = kt * MMA_K_F16 + c;
                if (d < ROPE) sQ16[_pfill_cidx_bf16_128(r, c)] = q_rope_bf16[r * ROPE + d];
            }
        }
        for (int i = tid; i < kv_len * MMA_K_F16; i += blockDim.x) {
            int r = i / MMA_K_F16, c = i % MMA_K_F16;
            int d = kt * MMA_K_F16 + c;
            if (d < ROPE) sK16[_pfill_cidx_bf16_128(r, c)] = k_rope_bf16[(int64_t)r * ROPE + d];
        }
        __syncthreads();
        if (is_mma_warp && lane == 0) {
            uint64_t dq = make_umma_desc_kmajor_none((uint32_t)__cvta_generic_to_shared(sQ16), 128);
            uint64_t dk = make_umma_desc_kmajor_none((uint32_t)__cvta_generic_to_shared(sK16), 128);
            umma_ss_f16(tb, dq, dk, idesc_f16_qk, kt > 0);
            asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
        }
        __syncthreads();
    }
    asm volatile("fence.sc.gpu;" ::: "memory");
    __syncthreads();

    // Add RoPE to noPE
    prefill_read_qk_rows<SK_TILE>(tb, sP, T_ACT, kv_len);
    __syncthreads();
    for (int i = tid; i < T_ACT * kv_len; i += blockDim.x) {
        // Rows are stored with stride SK_TILE (see the r * SK_TILE + c
        // accesses above), so map the flat index to (row, col) first. A plain
        // flat index only lands on the right elements when kv_len == SK_TILE.
        const int idx = (i / kv_len) * SK_TILE + (i % kv_len);
        sLogits[idx] += sP[idx];
    }
    __syncthreads();
    if (tid == 0) {
        printf("QK total (row 0, first 8): ");
        for (int c = 0; c < 8; c++) printf("%.4f ", sLogits[0 * SK_TILE + c] * scale);
        printf("\n");
        printf("QK total (row 1, first 8): ");
        for (int c = 0; c < 8; c++) printf("%.4f ", sLogits[1 * SK_TILE + c] * scale);
        printf("\n");
    }
    __syncthreads();

    // ---- Softmax ----
    // One thread per query row. sP receives the unnormalised exp() values;
    // the running max/sum are updated in the online-softmax form even though
    // there is only one KV tile here.
    for (int r = tid; r < T_ACT; r += blockDim.x) {
        float tile_max = -INFINITY;
        for (int c = 0; c < kv_len; c++)
            tile_max = fmaxf(tile_max, sLogits[r * SK_TILE + c] * scale);
        float tile_sum = 0.0f;
        for (int c = 0; c < kv_len; c++) {
            float pv = expf(sLogits[r * SK_TILE + c] * scale - tile_max);
            sP[r * SK_TILE + c] = pv;
            tile_sum += pv;
        }
        for (int c = kv_len; c < SK_TILE; c++) sP[r * SK_TILE + c] = 0.0f;
        float old_max = sRunningMax[r];
        float new_max = fmaxf(old_max, tile_max);
        float rescale_old = (old_max > -INFINITY) ? expf(old_max - new_max) : 0.0f;
        for (int d = 0; d < HD; d++) sOacc[r * HD + d] *= rescale_old;
        float rescale_new = expf(tile_max - new_max);
        sRunningSum[r] = sRunningSum[r] * rescale_old + tile_sum * rescale_new;
        sRunningMax[r] = new_max;
        // The logits are no longer needed: slot 0 of each row is reused to
        // hand the per-row rescale factor to the PV stage.
        sLogits[r * SK_TILE] = rescale_new;
    }
    __syncthreads();
    if (tid == 0) {
        printf("Softmax P (row 0, first 8): ");
        for (int c = 0; c < 8; c++) printf("%.6f ", sP[0 * SK_TILE + c]);
        printf(" sum=%.6f\n", sRunningSum[0]);
        printf("Softmax P (row 1, first 8): ");
        for (int c = 0; c < 8; c++) printf("%.6f ", sP[1 * SK_TILE + c]);
        printf(" sum=%.6f\n", sRunningSum[1]);
        printf("Rescale: row0=%.6f row1=%.6f\n", sLogits[0 * SK_TILE], sLogits[1 * SK_TILE]);
    }
    __syncthreads();

    // ---- PV: per query row ----
    // For each query row the 512 output dims are produced in 32 sub-blocks of
    // 16 dims, each accumulated over 8 K slices of 16 keys. V is rebuilt from
    // the K buffers: dequantised FP8 for d < NOPE, RoPE bf16 otherwise.
    for (int qr = 0; qr < T_ACT; qr++) {
        float p_rescale = sLogits[qr * SK_TILE];
        if (tid == 0) printf("PV for qr=%d: p_rescale=%.6f\n", qr, p_rescale);
        for (int n_sub = 0; n_sub < N_SUB; n_sub++) {
            int d_base = n_sub * 16;
            for (int pv_kt = 0; pv_kt < NKT_PV; pv_kt++) {
                const int col_start = pv_kt * MMA_K_F16;
                for (int i = tid; i < TILE_F16; i += blockDim.x) sPk[i] = 0;
                for (int i = tid; i < V_SUB_SZ; i += blockDim.x) sV[i] = 0;
                __syncthreads();
                // Only row qr of the P operand is populated.
                for (int c = tid; c < MMA_K_F16; c += blockDim.x) {
                    int gc = col_start + c;
                    sPk[_pfill_cidx_bf16_128(qr, c)] = f32_to_bf16(sP[qr * SK_TILE + gc]);
                }
                for (int i = tid; i < 16 * MMA_K_F16; i += blockDim.x) {
                    int dd = i / MMA_K_F16, kk = i % MMA_K_F16;
                    int row = col_start + kk;
                    int g_row = row;
                    int d = d_base + dd;
                    bf16_t vbits = 0;
                    if (row < kv_len) {
                        if (d < NOPE) {
                            uint8_t b = k_nope_fp8[(int64_t)g_row * NOPE + d];
                            float v = _prefill_fp8_to_f32(b) * k_nope_scale[g_row];
                            vbits = f32_to_bf16(v);
                        } else {
                            vbits = k_rope_bf16[(int64_t)g_row * ROPE + (d - NOPE)];
                        }
                    }
                    sV[_pfill_cidx_bf16_16(dd, kk)] = vbits;
                }
                __syncthreads();
                bool first = (pv_kt == 0);
                if (is_mma_warp && lane == 0) {
                    uint64_t dp = make_umma_desc_kmajor_none((uint32_t)__cvta_generic_to_shared(sPk), 128);
                    uint64_t dv = make_umma_desc_kmajor_none((uint32_t)__cvta_generic_to_shared(sV), 16);
                    umma_ss_f16(tb + n_sub * 16, dp, dv, idesc_pv, !first);
                    asm volatile("tcgen05.fence::after_thread_sync;" ::: "memory");
                }
                __syncthreads();
            }
        }
        // Read PV result for row qr
        asm volatile("fence.sc.gpu;" ::: "memory");
        __syncthreads();
        prefill_read_pv_all_subs<512, 32>(tb, qr, sOacc, p_rescale);
        __syncthreads();
        // Print first few accumulated values
        if (tid == 0 && qr == 0) {
            printf("sOacc qr=0 (first 8): ");
            for (int d = 0; d < 8; d++) printf("%.6f ", sOacc[0 * HD + d]);
            printf("\n");
        }
        if (tid == 0 && qr == 1) {
            printf("sOacc qr=1 (first 8): ");
            for (int d = 0; d < 8; d++) printf("%.6f ", sOacc[1 * HD + d]);
            printf("\n");
        }
        __syncthreads();
    }

    // Normalize and print final output
    if (tid == 0) {
        printf("sRunningSum: row0=%.6f row1=%.6f\n", sRunningSum[0], sRunningSum[1]);
        printf("sRunningMax: row0=%.6f row1=%.6f\n", sRunningMax[0], sRunningMax[1]);
        printf("Final output row0 (first 8): ");
        for (int d = 0; d < 8; d++) printf("%.6f ", sOacc[0 * HD + d] / sRunningSum[0]);
        printf("\n");
        printf("Final output row1 (first 8): ");
        for (int d = 0; d < 8; d++) printf("%.6f ", sOacc[1 * HD + d] / sRunningSum[1]);
        printf("\n");
        // Check for NaN
        bool has_nan0 = false, has_nan1 = false;
        for (int d = 0; d < HD; d++) {
            if (isnan(sOacc[0 * HD + d])) has_nan0 = true;
            if (isnan(sOacc[1 * HD + d])) has_nan1 = true;
        }
        printf("NaN check: row0=%s row1=%s\n", has_nan0 ? "YES" : "no", has_nan1 ? "YES" : "no");
    }
    // All TMEM reads finished before the last barrier of the PV loop, so the
    // allocating warp can release it while thread 0 prints from shared memory.
    if (is_mma_warp) tmem_dealloc(tb, TMEM_COLS);
}
// Abort the test with a non-zero exit code when a CUDA runtime call fails.
static void t2dbg_check_cuda(cudaError_t status, const char* what) {
    if (status != cudaSuccess) {
        fprintf(stderr, "CUDA error (%s): %s\n", what, cudaGetErrorString(status));
        exit(EXIT_FAILURE);
    }
}

// malloc that aborts the test instead of returning NULL.
static void* t2dbg_malloc(size_t bytes) {
    void* p = malloc(bytes);
    if (p == NULL) {
        fprintf(stderr, "Host allocation of %zu bytes failed\n", bytes);
        exit(EXIT_FAILURE);
    }
    return p;
}

// Allocate a device buffer of `bytes` bytes and copy `host` into it.
static void* t2dbg_upload(const void* host, size_t bytes, const char* what) {
    void* dev = NULL;
    t2dbg_check_cuda(cudaMalloc(&dev, bytes), what);
    t2dbg_check_cuda(cudaMemcpy(dev, host, bytes, cudaMemcpyHostToDevice), what);
    return dev;
}

// Per-row FP8 E4M3 quantization of the first `nope` columns of each `hd`-wide
// row of `src`: scale = amax / 448 (1 when that is below 1e-12), values are
// divided by the scale, clamped to [-448, 448] and converted.
static void t2dbg_quantize_nope(const float* src, int rows, int hd, int nope,
                                uint8_t* dst, float* scale) {
    for (int r = 0; r < rows; r++) {
        float amax = 0.0f;
        for (int c = 0; c < nope; c++) amax = fmaxf(amax, fabsf(src[r * hd + c]));
        float s = amax / 448.0f;
        if (s < 1e-12f) s = 1.0f;
        scale[r] = s;
        for (int c = 0; c < nope; c++) {
            float v = src[r * hd + c] / s;
            v = fmaxf(-448.0f, fminf(448.0f, v));
            __nv_fp8_e4m3 fp8 = __nv_fp8_e4m3(v);
            dst[r * nope + c] = fp8.__x;
        }
    }
}

// Copy columns [nope, nope + rope) of each `hd`-wide row of `src` as bf16 bits.
static void t2dbg_extract_rope(const float* src, int rows, int hd, int nope, int rope,
                               bf16_t* dst) {
    for (int r = 0; r < rows; r++)
        for (int c = 0; c < rope; c++)
            dst[r * rope + c] = f32_to_bf16_host(src[r * hd + nope + c]);
}

/**
 * Test driver: builds random Q (T x HD) and K (N x HD), quantizes the noPE
 * part to FP8 and the RoPE part to bf16, prints a CPU reference for QK,
 * softmax and the attention output, then runs prefill_t2_debug_kernel, which
 * prints the GPU intermediates. The comparison is visual; the exit code only
 * reflects whether every CUDA call and the kernel itself succeeded.
 *
 * @return 0 when the kernel launched and ran without a CUDA error, 1 otherwise
 *         (CUDA API / allocation failures exit with EXIT_FAILURE).
 */
int main() {
    constexpr int T = 2;
    constexpr int N = 128;
    constexpr int HD = 512;
    constexpr int NOPE = 448;
    constexpr int ROPE = 64;
    const float scale = 1.0f / sqrtf((float)HD);
    printf("=== Prefill T=2 Debug Test ===\n");
    printf("T=%d N=%d HD=%d NOPE=%d ROPE=%d scale=%.6f\n", T, N, HD, NOPE, ROPE, scale);

    // Generate random data on CPU, then upload
    srand(42);
    // Q: (T, HD) FP32 → quantize noPE to FP8, keep RoPE as BF16
    float* h_q = (float*)t2dbg_malloc(T * HD * sizeof(float));
    for (int i = 0; i < T * HD; i++) h_q[i] = (float)rand() / RAND_MAX * 0.5f - 0.25f;
    // K: (N, HD) FP32 → quantize noPE to FP8, keep RoPE as BF16
    float* h_k = (float*)t2dbg_malloc(N * HD * sizeof(float));
    for (int i = 0; i < N * HD; i++) h_k[i] = (float)rand() / RAND_MAX * 0.5f - 0.25f;

    // Q noPE FP8 quantization (per-row scale) and Q RoPE BF16
    uint8_t* h_q_nope_fp8 = (uint8_t*)t2dbg_malloc(T * NOPE);
    float* h_q_nope_scale = (float*)t2dbg_malloc(T * sizeof(float));
    bf16_t* h_q_rope_bf16 = (bf16_t*)t2dbg_malloc(T * ROPE * sizeof(bf16_t));
    t2dbg_quantize_nope(h_q, T, HD, NOPE, h_q_nope_fp8, h_q_nope_scale);
    t2dbg_extract_rope(h_q, T, HD, NOPE, ROPE, h_q_rope_bf16);

    // K noPE FP8 quantization and K RoPE BF16
    uint8_t* h_k_nope_fp8 = (uint8_t*)t2dbg_malloc(N * NOPE);
    float* h_k_nope_scale = (float*)t2dbg_malloc(N * sizeof(float));
    bf16_t* h_k_rope_bf16 = (bf16_t*)t2dbg_malloc(N * ROPE * sizeof(bf16_t));
    t2dbg_quantize_nope(h_k, N, HD, NOPE, h_k_nope_fp8, h_k_nope_scale);
    t2dbg_extract_rope(h_k, N, HD, NOPE, ROPE, h_k_rope_bf16);

    // Upload to GPU
    uint8_t* d_q_nope_fp8 = (uint8_t*)t2dbg_upload(h_q_nope_fp8, T * NOPE, "upload q_nope_fp8");
    float* d_q_nope_scale = (float*)t2dbg_upload(h_q_nope_scale, T * sizeof(float), "upload q_nope_scale");
    bf16_t* d_q_rope_bf16 = (bf16_t*)t2dbg_upload(h_q_rope_bf16, T * ROPE * sizeof(bf16_t), "upload q_rope_bf16");
    uint8_t* d_k_nope_fp8 = (uint8_t*)t2dbg_upload(h_k_nope_fp8, N * NOPE, "upload k_nope_fp8");
    float* d_k_nope_scale = (float*)t2dbg_upload(h_k_nope_scale, N * sizeof(float), "upload k_nope_scale");
    bf16_t* d_k_rope_bf16 = (bf16_t*)t2dbg_upload(h_k_rope_bf16, N * ROPE * sizeof(bf16_t), "upload k_rope_bf16");

    // Compute CPU reference QK
    printf("\n=== CPU Reference QK ===\n");
    float ref_qk[T][N] = {};
    for (int r = 0; r < T; r++) {
        for (int c = 0; c < N; c++) {
            float dot = 0.0f;
            // noPE: FP8 dequant dot product
            for (int d = 0; d < NOPE; d++) {
                float qv = fp8_to_f32(h_q_nope_fp8[r * NOPE + d]) * h_q_nope_scale[r];
                float kv = fp8_to_f32(h_k_nope_fp8[c * NOPE + d]) * h_k_nope_scale[c];
                dot += qv * kv;
            }
            // RoPE: BF16 dot product
            for (int d = 0; d < ROPE; d++) {
                float qv = bf16_to_f32_host(h_q_rope_bf16[r * ROPE + d]);
                float kv = bf16_to_f32_host(h_k_rope_bf16[c * ROPE + d]);
                dot += qv * kv;
            }
            ref_qk[r][c] = dot * scale;
        }
    }
    printf("CPU ref QK (row 0, first 8): ");
    for (int c = 0; c < 8; c++) printf("%.4f ", ref_qk[0][c]);
    printf("\n");
    printf("CPU ref QK (row 1, first 8): ");
    for (int c = 0; c < 8; c++) printf("%.4f ", ref_qk[1][c]);
    printf("\n");

    // Compute CPU reference softmax (max-subtracted, normalised per row)
    printf("\n=== CPU Reference Softmax + Attention ===\n");
    float ref_softmax[T][N] = {};
    for (int r = 0; r < T; r++) {
        float mx = ref_qk[r][0];
        for (int c = 1; c < N; c++) mx = fmaxf(mx, ref_qk[r][c]);
        float sm = 0.0f;
        for (int c = 0; c < N; c++) {
            ref_softmax[r][c] = expf(ref_qk[r][c] - mx);
            sm += ref_softmax[r][c];
        }
        for (int c = 0; c < N; c++) ref_softmax[r][c] /= sm;
    }
    printf("CPU ref softmax (row 0, first 8): ");
    for (int c = 0; c < 8; c++) printf("%.6f ", ref_softmax[0][c]);
    printf("\n");

    // Compute CPU reference attention output; V is the dequantised K row
    // (FP8 noPE dims followed by BF16 RoPE dims), matching the kernel.
    float ref_out[T][HD] = {};
    for (int r = 0; r < T; r++) {
        for (int d = 0; d < HD; d++) {
            float val = 0.0f;
            for (int c = 0; c < N; c++) {
                float kv;
                if (d < NOPE) {
                    kv = fp8_to_f32(h_k_nope_fp8[c * NOPE + d]) * h_k_nope_scale[c];
                } else {
                    kv = bf16_to_f32_host(h_k_rope_bf16[c * ROPE + (d - NOPE)]);
                }
                val += ref_softmax[r][c] * kv;
            }
            ref_out[r][d] = val;
        }
    }
    printf("CPU ref output (row 0, first 8): ");
    for (int d = 0; d < 8; d++) printf("%.6f ", ref_out[0][d]);
    printf("\n");
    printf("CPU ref output (row 1, first 8): ");
    for (int d = 0; d < 8; d++) printf("%.6f ", ref_out[1][d]);
    printf("\n");

    // Launch debug kernel
    printf("\n=== GPU Kernel Execution ===\n");
    // Generous dynamic shared-memory budget for the kernel's carve-up
    // (original note: stay under the 232KB limit). Above 48KB the opt-in
    // attribute below is required.
    int smem_size = 200 * 1024;
    t2dbg_check_cuda(
        cudaFuncSetAttribute(prefill_t2_debug_kernel,
                             cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size),
        "cudaFuncSetAttribute(MaxDynamicSharedMemorySize)");
    prefill_t2_debug_kernel<<<dim3(1,1,1), 192, smem_size>>>(
        d_q_nope_fp8, d_q_nope_scale, d_q_rope_bf16,
        d_k_nope_fp8, d_k_nope_scale, d_k_rope_bf16,
        T, N, HD, NOPE, ROPE, scale);
    // Launch-configuration errors show up in cudaGetLastError(); faults that
    // happen while the kernel runs are reported by the synchronize call.
    cudaError_t launch_err = cudaGetLastError();
    cudaError_t exec_err = cudaDeviceSynchronize();
    int rc = 0;
    if (launch_err != cudaSuccess) {
        printf("Kernel launch FAILED: %s\n", cudaGetErrorString(launch_err));
        rc = 1;
    } else if (exec_err != cudaSuccess) {
        printf("Kernel execution FAILED: %s\n", cudaGetErrorString(exec_err));
        rc = 1;
    } else {
        printf("Kernel completed successfully.\n");
    }

    // Cleanup (best effort: after a kernel fault these calls may fail too)
    cudaFree(d_q_nope_fp8); cudaFree(d_q_nope_scale); cudaFree(d_q_rope_bf16);
    cudaFree(d_k_nope_fp8); cudaFree(d_k_nope_scale); cudaFree(d_k_rope_bf16);
    free(h_q); free(h_k);
    free(h_q_nope_fp8); free(h_q_nope_scale); free(h_q_rope_bf16);
    free(h_k_nope_fp8); free(h_k_nope_scale); free(h_k_rope_bf16);
    printf("\n=== Done ===\n");
    return rc;
}