diff --git a/tests/unit/test_umma_qk.cu b/tests/unit/test_umma_qk.cu index cf674729..8d69704d 100644 --- a/tests/unit/test_umma_qk.cu +++ b/tests/unit/test_umma_qk.cu @@ -96,10 +96,19 @@ test_umma_qk_hd16( // K-major NONE: LBO = BLOCK_MN * 16, SBO = 128 uint64_t desc_q = make_umma_desc_kmajor_none(sQ_smem, 128); uint64_t desc_k = make_umma_desc_kmajor_none(sK_smem, 128); - - // Instruction descriptor: dtype=FP32, atype=BF16, btype=BF16, M=128, N=128 uint32_t idesc = make_idesc(128, 128); + // Debug: write descriptor values to output (positions 128-135) + if (tid == 0) { + memcpy(&s_out[128], &desc_q, 8); + memcpy(&s_out[130], &desc_k, 8); + memcpy(&s_out[132], &idesc, 4); + s_out[133] = (float)sQ_smem; + s_out[134] = (float)sK_smem; + s_out[135] = (float)__cvta_generic_to_shared(sQ); + } + __syncthreads(); + // ================================================================ // Call tcgen05.mma SS // ================================================================ @@ -160,7 +169,7 @@ int main() { bf16_t* h_q = (bf16_t*)malloc(HD * sizeof(bf16_t)); bf16_t* h_k = (bf16_t*)malloc(SK * HD * sizeof(bf16_t)); - float* h_s_out = (float*)malloc(SK * sizeof(float)); + float* h_s_out = (float*)malloc(256 * sizeof(float)); float* h_s_scalar = (float*)malloc(SK * sizeof(float)); srand(42); @@ -173,12 +182,12 @@ int main() { float *d_s_out, *d_s_scalar; cudaMalloc(&d_q, HD * sizeof(bf16_t)); cudaMalloc(&d_k, SK * HD * sizeof(bf16_t)); - cudaMalloc(&d_s_out, SK * sizeof(float)); + cudaMalloc(&d_s_out, 256 * sizeof(float)); cudaMalloc(&d_s_scalar, SK * sizeof(float)); cudaMemcpy(d_q, h_q, HD * sizeof(bf16_t), cudaMemcpyHostToDevice); cudaMemcpy(d_k, h_k, SK * HD * sizeof(bf16_t), cudaMemcpyHostToDevice); - cudaMemset(d_s_out, 0, SK * sizeof(float)); + cudaMemset(d_s_out, 0, 256 * sizeof(float)); cudaMemset(d_s_scalar, 0, SK * sizeof(float)); int smem_size = 4 + 16 + 128*16*2 + 128*16*2 + 16*4 + 256; @@ -194,7 +203,7 @@ int main() { return 1; } - cudaMemcpy(h_s_out, d_s_out, SK * sizeof(float), cudaMemcpyDeviceToHost); + cudaMemcpy(h_s_out, d_s_out, 256 * sizeof(float), cudaMemcpyDeviceToHost); cudaMemcpy(h_s_scalar, d_s_scalar, SK * sizeof(float), cudaMemcpyDeviceToHost); printf("\nS[0,0..7] (MMA): "); @@ -210,6 +219,23 @@ int main() { } float rel_err = (max_val > 0) ? max_diff / max_val : max_diff; printf("Max abs diff: %.6f, Max rel err: %.6f\n", max_diff, rel_err); + + // Debug: print descriptor values + uint64_t dbg_desc_q, dbg_desc_k; + uint32_t dbg_idesc; + memcpy(&dbg_desc_q, &h_s_out[128], 8); + memcpy(&dbg_desc_k, &h_s_out[130], 8); + memcpy(&dbg_idesc, &h_s_out[132], 4); + printf("desc_q = 0x%016lx\n", dbg_desc_q); + printf("desc_k = 0x%016lx\n", dbg_desc_k); + printf("idesc = 0x%08x\n", dbg_idesc); + printf("sQ_smem = %.0f\n", h_s_out[133]); + printf("sK_smem = %.0f\n", h_s_out[134]); + printf(" desc_q start_addr = %lu, LBO = %lu, SBO = %lu\n", + dbg_desc_q & 0x3FFF, (dbg_desc_q >> 16) & 0x3FFF, (dbg_desc_q >> 32) & 0x3FFF); + printf(" desc_k start_addr = %lu, LBO = %lu, SBO = %lu\n", + dbg_desc_k & 0x3FFF, (dbg_desc_k >> 16) & 0x3FFF, (dbg_desc_k >> 32) & 0x3FFF); + printf("Test %s\n", rel_err < 0.01f ? "PASSED" : "FAILED"); cudaFree(d_q); cudaFree(d_k); cudaFree(d_s_out); cudaFree(d_s_scalar);