test_umma_qk: add descriptor debug output

This commit is contained in:
2026-05-28 09:20:12 +00:00
parent 0f6907b001
commit ea6b42e649

View File

@@ -96,10 +96,19 @@ test_umma_qk_hd16(
// K-major NONE: LBO = BLOCK_MN * 16, SBO = 128
uint64_t desc_q = make_umma_desc_kmajor_none(sQ_smem, 128);
uint64_t desc_k = make_umma_desc_kmajor_none(sK_smem, 128);
// Instruction descriptor: dtype=FP32, atype=BF16, btype=BF16, M=128, N=128
uint32_t idesc = make_idesc(128, 128);
// Debug: write descriptor values to output (positions 128-135)
// Only the thread with tid == 0 performs the stores below.
// Slot layout (assuming s_out elements are 4-byte floats, which is how the
// host-side buffer h_s_out is declared — confirm against the kernel signature):
//   [128..129] raw 64-bit bits of desc_q
//   [130..131] raw 64-bit bits of desc_k
//   [132]      raw 32-bit bits of idesc
//   [133]      sQ_smem, converted to float by value
//   [134]      sK_smem, converted to float by value
//   [135]      shared-space address of sQ, converted to float by value
if (tid == 0) {
// memcpy copies the bit pattern unchanged; these slots are NOT meaningful
// as floats and must be read back with a matching memcpy on the host.
memcpy(&s_out[128], &desc_q, 8);
memcpy(&s_out[130], &desc_k, 8);
memcpy(&s_out[132], &idesc, 4);
// These three are numeric conversions, not bit copies. A float represents
// integers exactly only up to 2^24, so larger values would be rounded.
s_out[133] = (float)sQ_smem;
s_out[134] = (float)sK_smem;
// NOTE(review): presumably stored to cross-check sQ_smem against the address
// the intrinsic reports for sQ — confirm how sQ_smem is computed.
s_out[135] = (float)__cvta_generic_to_shared(sQ);
}
// Barrier sits outside the tid == 0 branch so every thread in the block reaches it.
__syncthreads();
// ================================================================
// Call tcgen05.mma SS
// ================================================================
@@ -160,7 +169,7 @@ int main() {
bf16_t* h_q = (bf16_t*)malloc(HD * sizeof(bf16_t));
bf16_t* h_k = (bf16_t*)malloc(SK * HD * sizeof(bf16_t));
float* h_s_out = (float*)malloc(SK * sizeof(float));
float* h_s_out = (float*)malloc(256 * sizeof(float));
float* h_s_scalar = (float*)malloc(SK * sizeof(float));
srand(42);
@@ -173,12 +182,12 @@ int main() {
float *d_s_out, *d_s_scalar;
cudaMalloc(&d_q, HD * sizeof(bf16_t));
cudaMalloc(&d_k, SK * HD * sizeof(bf16_t));
cudaMalloc(&d_s_out, SK * sizeof(float));
cudaMalloc(&d_s_out, 256 * sizeof(float));
cudaMalloc(&d_s_scalar, SK * sizeof(float));
cudaMemcpy(d_q, h_q, HD * sizeof(bf16_t), cudaMemcpyHostToDevice);
cudaMemcpy(d_k, h_k, SK * HD * sizeof(bf16_t), cudaMemcpyHostToDevice);
cudaMemset(d_s_out, 0, SK * sizeof(float));
cudaMemset(d_s_out, 0, 256 * sizeof(float));
cudaMemset(d_s_scalar, 0, SK * sizeof(float));
int smem_size = 4 + 16 + 128*16*2 + 128*16*2 + 16*4 + 256;
@@ -194,7 +203,7 @@ int main() {
return 1;
}
cudaMemcpy(h_s_out, d_s_out, SK * sizeof(float), cudaMemcpyDeviceToHost);
cudaMemcpy(h_s_out, d_s_out, 256 * sizeof(float), cudaMemcpyDeviceToHost);
cudaMemcpy(h_s_scalar, d_s_scalar, SK * sizeof(float), cudaMemcpyDeviceToHost);
printf("\nS[0,0..7] (MMA): ");
@@ -210,6 +219,23 @@ int main() {
}
float rel_err = (max_val > 0) ? max_diff / max_val : max_diff;
printf("Max abs diff: %.6f, Max rel err: %.6f\n", max_diff, rel_err);
// Debug: print descriptor values
// The kernel-side debug block stores these in h_s_out[128..134]:
//   [128..129] raw bits of desc_q, [130..131] raw bits of desc_k,
//   [132] raw bits of idesc, [133] sQ_smem and [134] sK_smem as float values.
// (Slot 135 is also written by the kernel but is not printed here.)
// The descriptor slots hold bit patterns, not floats, so they are recovered
// with memcpy rather than by reading the float elements.
uint64_t dbg_desc_q, dbg_desc_k;
uint32_t dbg_idesc;
memcpy(&dbg_desc_q, &h_s_out[128], sizeof(dbg_desc_q));
memcpy(&dbg_desc_k, &h_s_out[130], sizeof(dbg_desc_k));
memcpy(&dbg_idesc, &h_s_out[132], sizeof(dbg_idesc));
// uint64_t is only `unsigned long` on LP64 hosts; cast to unsigned long long
// and use the ll length modifier so the output is correct on LLP64 hosts too.
printf("desc_q = 0x%016llx\n", (unsigned long long)dbg_desc_q);
printf("desc_k = 0x%016llx\n", (unsigned long long)dbg_desc_k);
printf("idesc = 0x%08x\n", dbg_idesc);
printf("sQ_smem = %.0f\n", h_s_out[133]);
printf("sK_smem = %.0f\n", h_s_out[134]);
// Split each 64-bit descriptor into three 14-bit fields at bit offsets 0, 16
// and 32. The values are printed exactly as stored in the descriptor.
// NOTE(review): these are presumably encoded (scaled) values rather than byte
// counts — confirm against make_umma_desc_kmajor_none.
printf(" desc_q start_addr = %llu, LBO = %llu, SBO = %llu\n",
       (unsigned long long)(dbg_desc_q & 0x3FFF),
       (unsigned long long)((dbg_desc_q >> 16) & 0x3FFF),
       (unsigned long long)((dbg_desc_q >> 32) & 0x3FFF));
printf(" desc_k start_addr = %llu, LBO = %llu, SBO = %llu\n",
       (unsigned long long)(dbg_desc_k & 0x3FFF),
       (unsigned long long)((dbg_desc_k >> 16) & 0x3FFF),
       (unsigned long long)((dbg_desc_k >> 32) & 0x3FFF));
printf("Test %s\n", rel_err < 0.01f ? "PASSED" : "FAILED");
cudaFree(d_q); cudaFree(d_k); cudaFree(d_s_out); cudaFree(d_s_scalar);