From b4b11db0fa1dd37bdde62258b4d240ec917bbcca Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 19:12:13 +0000 Subject: [PATCH] Fix SMEM-P: use BF16 copy atom and BF16 source with QK C-fragment layout --- dsv4/kernels/attention/fmha.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 643e0d40..9bd72ae3 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -263,9 +263,10 @@ class FmhaKernel: # This copy writes those values to sP which has PV A-operand SMEM layout. # According to STAGE_D.md: use tcgen05.copy.St32x32bOp with Float32 (not BF16) # and use make_tiled_copy_C(store_atom, qk_mma) to partition by QK C-fragment + # BUT: sP is BF16, so we need BF16 copy atom or convert Float32→BF16 smem_copy_atom = cute.make_copy_atom( tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), - Float32, + self.q_dtype, # BF16 to match sP num_bits_per_copy=128, ) tiled_smem_copy = cute.make_tiled_copy_C(smem_copy_atom, qk_mma) @@ -346,16 +347,16 @@ class FmhaKernel: else: # SMEM-P: Use QK C-fragment layout for source (not TMEM layout) # rP_bf16 uses tTMEM_LOADrS.layout (TMEM layout) causing rank mismatch - # Create view with QK C-fragment layout (tStS0.layout) using Float32 source (rP_words) + # Create BF16 view with QK C-fragment layout for copying to BF16 SMEM rP_qk_layout = tStS0.layout # QK C-fragment layout for this thread - rP_qk = cute.make_tensor(cute.recast_ptr(rP_words.iterator, dtype=Float32), rP_qk_layout) + rP_bf16_qk = cute.make_tensor(cute.recast_ptr(rP_bf16.iterator, dtype=self.q_dtype), rP_qk_layout) # Partition source with QK layout - tSMEM_CPYrP_qk = thr_smem_copy.partition_S(rP_qk) + tSMEM_CPYrP_qk = thr_smem_copy.partition_S(rP_bf16_qk) # Debug shapes print(f"[SMEM-P PROPER] rP_bf16 shape: {cute.shape(rP_bf16)}, layout: TMEM") - print(f"[SMEM-P PROPER] rP_qk shape: {cute.shape(rP_qk)}, layout: QK C-fragment (Float32)") + print(f"[SMEM-P PROPER] rP_bf16_qk shape: {cute.shape(rP_bf16_qk)}, layout: QK C-fragment (BF16)") print(f"[SMEM-P PROPER] tSMEM_CPYrP_qk shape: {cute.shape(tSMEM_CPYrP_qk)} rank: {len(cute.shape(tSMEM_CPYrP_qk))}") print(f"[SMEM-P PROPER] tSMEM_CPYsP shape: {cute.shape(tSMEM_CPYsP)} rank: {len(cute.shape(tSMEM_CPYsP))}")