debug: print UMMA descriptor values for diagnosis
This commit is contained in:
@@ -78,9 +78,21 @@ fmha_qk_verify(
|
||||
uint32_t sQ_smem = __cvta_generic_to_shared(sQ);
|
||||
uint32_t sK_smem = __cvta_generic_to_shared(sK);
|
||||
|
||||
if (tid == 0) {
|
||||
printf("[qk] sQ_smem=0x%x sK_smem=0x%x sQ_align=%d sK_align=%d\n",
|
||||
sQ_smem, sK_smem, sQ_smem % 16, sK_smem % 16);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
uint64_t desc_q = make_umma_desc_bf16(sQ_smem, 128, HD, HD, UmmaMajor::MN);
|
||||
uint64_t desc_k = make_umma_desc_bf16(sK_smem, 128, HD, HD, UmmaMajor::K);
|
||||
|
||||
if (tid == 0) {
|
||||
printf("[qk] desc_q=0x%016llx desc_k=0x%016llx\n",
|
||||
(unsigned long long)desc_q, (unsigned long long)desc_k);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// MMA is called by ONE lane (elect_one_sync pattern)
|
||||
if (wid == 0 && lane == 0) {
|
||||
umma_ss_f16(tmem_s, desc_q, desc_k, /*accumulate=*/false);
|
||||
|
||||
Reference in New Issue
Block a user