diff --git a/tests/unit/test_umma_qk_hd64.cu b/tests/unit/test_umma_qk_hd64.cu index 327c6cc7..e8d09579 100644 --- a/tests/unit/test_umma_qk_hd64.cu +++ b/tests/unit/test_umma_qk_hd64.cu @@ -54,6 +54,12 @@ test_umma_qk_hd64(const bf16_t* q, const bf16_t* k, } __syncthreads(); + // Simple sanity check: thread 0 writes to sQ[0] + if (tid == 0) sQ[0] = f32_to_bf16(1.0f); + __syncthreads(); + if (tid == 0) s_out[250] = bf16_to_f32(sQ[0]); // Should be 1.0 + __syncthreads(); + // Write Q (1, hd) to row 0 of sQ in canonical layout for (int d = tid; d < hd; d += N_WARPS * 32) { int core_k = d / 8, local_c = d % 8;