debug: print UMMA descriptor values for diagnosis

This commit is contained in:
2026-05-28 08:03:53 +00:00
parent c5f7a9a15c
commit fe7d561143

View File

@@ -78,9 +78,21 @@ fmha_qk_verify(
uint32_t sQ_smem = __cvta_generic_to_shared(sQ);
uint32_t sK_smem = __cvta_generic_to_shared(sK);
if (tid == 0) {
printf("[qk] sQ_smem=0x%x sK_smem=0x%x sQ_align=%d sK_align=%d\n",
sQ_smem, sK_smem, sQ_smem % 16, sK_smem % 16);
}
__syncthreads();
uint64_t desc_q = make_umma_desc_bf16(sQ_smem, 128, HD, HD, UmmaMajor::MN);
uint64_t desc_k = make_umma_desc_bf16(sK_smem, 128, HD, HD, UmmaMajor::K);
if (tid == 0) {
printf("[qk] desc_q=0x%016llx desc_k=0x%016llx\n",
(unsigned long long)desc_q, (unsigned long long)desc_k);
}
__syncthreads();
// The MMA is issued by exactly one thread (wid == 0, lane == 0), mirroring the elect_one_sync pattern with a fixed lane
if (wid == 0 && lane == 0) {
umma_ss_f16(tmem_s, desc_q, desc_k, /*accumulate=*/false);