From 4fe7f9dc375a0fa914ef3585f35f7c2cc8bcc870 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Wed, 3 Jun 2026 00:24:20 +0000 Subject: [PATCH] Fix B1 FMHA: swap V matrix canonical layout args (dd, kk) not (kk, dd) ROOT CAUSE: canon_idx_bf16_16x16(kk, dd) was swapping the outer/inner group structure compared to the working TMA-loaded V layout in the multitile kernel. Working layout: (lr/8)*128 + (dd/8)*64 + (dd%8)*8 + (lr%8) B1 with (kk,dd): (dd/8)*128 + (kk/8)*64 + (kk%8)*8 + (dd%8) <- WRONG B1 with (dd,kk): (kk/8)*128 + (dd/8)*64 + (dd%8)*8 + (kk%8) <- CORRECT This caused the V matrix to be loaded into SMEM with transposed group structure, producing garbage output (cos=0.158 vs BF16 reference). --- dsv4/kernels/attention/fmha_mixed_fp8_decode.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dsv4/kernels/attention/fmha_mixed_fp8_decode.cuh b/dsv4/kernels/attention/fmha_mixed_fp8_decode.cuh index 65cd4b04..eb56251d 100644 --- a/dsv4/kernels/attention/fmha_mixed_fp8_decode.cuh +++ b/dsv4/kernels/attention/fmha_mixed_fp8_decode.cuh @@ -307,7 +307,7 @@ fmha_mixed_fp8_decode_kernel(FmhaMixedFp8DecodeParams p) { } // B is (K=16 rows, N=16 cols). Reuse BF16 canonical with rows=16 // by embedding into the first 16 rows of a 128-row tile; MMA_N=16. - sV[canon_idx_bf16_16x16(kk, dd)] = vbits; + sV[canon_idx_bf16_16x16(dd, kk)] = vbits; } __syncthreads();