diff --git a/tests/unit/test_fmha_hd64_n16_v2.cu b/tests/unit/test_fmha_hd64_n16_v2.cu index bdf6d756..cb5a9d2a 100644 --- a/tests/unit/test_fmha_hd64_n16_v2.cu +++ b/tests/unit/test_fmha_hd64_n16_v2.cu @@ -41,7 +41,7 @@ test_fmha_hd64_n16(const bf16_t* q, const bf16_t* k, const bf16_t* v, bf16_t* sK0 = sQ0 + NKT_QK * TILE_SZ; bf16_t* sPk = (bf16_t*)(((uintptr_t)(sK0 + NKT_QK * TILE_SZ) + 127) & ~(uintptr_t)127); bf16_t* sV = (bf16_t*)(((uintptr_t)(sPk + TILE_SZ) + 127) & ~(uintptr_t)127); - float* s_p_vals = (float*)(sV + NKT_PV * V_SUB_TILE_SZ); // Only 1 V sub-tile at a time + float* s_p_vals = (float*)(sV + V_SUB_TILE_SZ); // Only 1 V sub-tile at a time // Load Q K-tiles (4 × (128,16) canonical) for (int kt = 0; kt < NKT_QK; kt++) { @@ -234,6 +234,9 @@ int main() { cudaFuncSetAttribute(test_fmha_hd64_n16, cudaFuncAttributeMaxDynamicSharedMemorySize, smem); test_fmha_hd64_n16<<<1, 128, smem>>>(d_q, d_k, d_v, d_o, d_o_scalar, SCALE); + cudaError_t launch_err = cudaGetLastError(); + if (launch_err != cudaSuccess) { printf("LAUNCH ERROR: %s\n", cudaGetErrorString(launch_err)); return 1; } + cudaError_t err = cudaDeviceSynchronize(); if (err != cudaSuccess) { printf("CUDA ERROR: %s\n", cudaGetErrorString(err)); return 1; }