Fix B1 FMHA: swap V matrix canonical layout args (dd, kk) not (kk, dd)

ROOT CAUSE: canon_idx_bf16_16x16(kk, dd) was swapping the outer/inner group
structure compared to the working TMA-loaded V layout in the multitile kernel.

Working layout: (lr/8)*128 + (dd/8)*64 + (dd%8)*8 + (lr%8)
B1 with (kk,dd): (dd/8)*128 + (kk/8)*64 + (kk%8)*8 + (dd%8)  <- WRONG
B1 with (dd,kk): (kk/8)*128 + (dd/8)*64 + (dd%8)*8 + (kk%8)  <- CORRECT

This caused the V matrix to be loaded into SMEM with transposed group
structure, producing garbage output (cos=0.158 vs BF16 reference).
This commit is contained in:
2026-06-03 00:24:20 +00:00
parent 29a95a3db6
commit 4fe7f9dc37

View File

@@ -307,7 +307,7 @@ fmha_mixed_fp8_decode_kernel(FmhaMixedFp8DecodeParams p) {
}
// B is (K=16 rows, N=16 cols). Reuse BF16 canonical with rows=16
// by embedding into the first 16 rows of a 128-row tile; MMA_N=16.
sV[canon_idx_bf16_16x16(kk, dd)] = vbits;
sV[canon_idx_bf16_16x16(dd, kk)] = vbits;
}
__syncthreads();