Fix B1 FMHA: swap V matrix canonical layout args (dd, kk) not (kk, dd)
ROOT CAUSE: canon_idx_bf16_16x16(kk, dd) was swapping the outer/inner group structure compared to the working TMA-loaded V layout in the multitile kernel.

Working layout:  (lr/8)*128 + (dd/8)*64 + (dd%8)*8 + (lr%8)
B1 with (kk,dd): (dd/8)*128 + (kk/8)*64 + (kk%8)*8 + (dd%8) <- WRONG
B1 with (dd,kk): (kk/8)*128 + (dd/8)*64 + (dd%8)*8 + (kk%8) <- CORRECT

This caused the V matrix to be loaded into SMEM with transposed group structure, producing garbage output (cos=0.158 vs BF16 reference).
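The index formulas above can be checked on the host with a small sketch. Note that `canon_idx_sketch` is a hypothetical reconstruction inferred only from the two "B1 with ..." lines in this message, not the kernel's actual `canon_idx_bf16_16x16`; `working_idx` and `count_mismatches` are likewise illustrative names, and the sketch assumes that `lr` in the working layout plays the role of `kk`.

```cpp
#include <cassert>

// Hypothetical reconstruction of canon_idx_bf16_16x16(a, b), inferred from the
// formulas quoted in the commit message. The real helper may differ.
constexpr int canon_idx_sketch(int a, int b) {
    return (b / 8) * 128 + (a / 8) * 64 + (a % 8) * 8 + (b % 8);
}

// Working TMA-loaded V layout as quoted in the commit message.
// Assumption: lr is the K-row index (kk) and dd is the column index.
constexpr int working_idx(int lr, int dd) {
    return (lr / 8) * 128 + (dd / 8) * 64 + (dd % 8) * 8 + (lr % 8);
}

// Count the elements of a 16x16 tile whose SMEM index differs from the
// working layout. swapped == true uses the fixed (dd, kk) argument order,
// swapped == false uses the original (kk, dd) order.
int count_mismatches(bool swapped) {
    int mismatches = 0;
    for (int kk = 0; kk < 16; ++kk) {
        for (int dd = 0; dd < 16; ++dd) {
            const int idx = swapped ? canon_idx_sketch(dd, kk)
                                    : canon_idx_sketch(kk, dd);
            if (idx != working_idx(kk, dd)) {
                ++mismatches;
            }
        }
    }
    return mismatches;
}
```

Under these assumptions, `count_mismatches(true)` returns 0 and `count_mismatches(false)` returns 240: with the original argument order only the 16 diagonal elements (kk == dd) land at the right index, which fits the near-random output reported above.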
@@ -307,7 +307,7 @@ fmha_mixed_fp8_decode_kernel(FmhaMixedFp8DecodeParams p) {
 }
 // B is (K=16 rows, N=16 cols). Reuse BF16 canonical with rows=16
 // by embedding into the first 16 rows of a 128-row tile; MMA_N=16.
-sV[canon_idx_bf16_16x16(kk, dd)] = vbits;
+sV[canon_idx_bf16_16x16(dd, kk)] = vbits;
 }
 __syncthreads();