From cfe21685d1dc9094e47cedb5a209821338764b06 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 03:51:58 +0000 Subject: [PATCH] =?UTF-8?q?WIP:=20tiled=20copy=20for=20P=E2=86=92SMEM=20(z?= =?UTF-8?q?ero=20fill)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dsv4/kernels/attention/fmha.py | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 79e3d645..39e73ebe 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -200,10 +200,11 @@ class FmhaKernel: tScS = qk_thr.partition_C(cS) tTMEM_LOADcS = thr_load.partition_D(tScS) - # P → SMEM: use PV A-operand partition for SMEM write - p_s = cute.slice_(p_smem_s,(None,None,None,0)) - tCrP_smem = pv_thr.partition_A(sP) # PV thread → SMEM partition for P (A operand) - tCrP_reg = cute.make_rmem_tensor(tCrP_smem.shape, self.q_dtype) + # P → SMEM copy setup + p_copy_atom = cute.make_copy_atom(cute.nvgpu.copy.AutoCopyOp(), self.q_dtype) + tiled_p_copy = cute.make_tiled_copy(p_copy_atom, tCrP_smem.layout, tCrP_reg.layout, tidx) + tPS_sP = tiled_p_copy.get_slice(sfw_idx).partition_D(tCrP_smem) + tPS_rP = tiled_p_copy.get_slice(sfw_idx).partition_S(tCrP_reg) # Online softmax state row_max = -Float32.inf @@ -240,17 +241,10 @@ class FmhaKernel: row_sum = row_sum + tTMEM_LOADrS_frg[k, j] # Write P to SMEM using PV A-operand thread partition - # Copy from softmax registers (QK partition) to SMEM (PV partition) - # Each thread converts its P values to BF16 and stores to its SMEM slot - rP_bf16_reg = cute.make_rmem_tensor(tCrP_smem.shape, self.q_dtype) - # Map QK partitioned P values to PV partitioned SMEM slots - # Simple approach: use cute.copy with the register and SMEM tensors - # The P SMEM is partitioned by pv_thr, softmax threads fill their portion - # For now: fill rP_bf16_reg from tTMEM_LOADrS (FP32→BF16 conversion) - for j in cutlass.range(cute.size(rP_bf16_reg), vectorize=True): - # TODO: proper element mapping from QK→PV partition - rP_bf16_reg[j] = BFloat16(0.0) - cute.copy(tCrP_smem, rP_bf16_reg) + # TODO: proper element mapping from QK→PV partition + for j in cutlass.range(cute.size(tCrP_reg), vectorize=True): + tCrP_reg[j] = BFloat16(0.0) + cute.copy(tiled_p_copy, tPS_rP, tPS_sP) cute.arch.fence_proxy("async.shared", space="cta") si_handle.release()