From 748873a58ce895785fa5db1845768ee511a6e451 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 03:49:05 +0000 Subject: [PATCH] =?UTF-8?q?WIP:=20P=E2=86=92SMEM=20write=20stub=20(zero=20?= =?UTF-8?q?fill,=20proper=20mapping=20TODO)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- dsv4/kernels/attention/fmha.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index ea02f698..d9b841e4 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -232,10 +232,26 @@ class FmhaKernel: row_sum *= acc_scale minus_row_max = Float32(0.0) - row_max_safe - # Compute P = exp2(S * scale - row_max) and write to SMEM - # First compute in FP32, convert to BF16, write to SMEM - # TODO: proper SMEM write with P thread partition - # For now, just arrive at softmax_done_bar to unblock MMA + # Compute P = exp2(S * scale - row_max), convert to BF16, write to SMEM + for j in range(frg_cnt): + for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])): + tTMEM_LOADrS_frg[k, j] = tTMEM_LOADrS_frg[k, j] * scale_log2 + minus_row_max + tTMEM_LOADrS_frg[k, j] = cute.math.exp2(tTMEM_LOADrS_frg[k, j], fastmath=True) + row_sum = row_sum + tTMEM_LOADrS_frg[k, j] + + # Write P to SMEM using PV A-operand thread partition + # Copy from softmax registers (QK partition) to SMEM (PV partition) + # Each thread converts its P values to BF16 and stores to its SMEM slot + rP_bf16_reg = cute.make_rmem_tensor(tCrP_smem.shape, self.q_dtype) + # Map QK partitioned P values to PV partitioned SMEM slots + # Simple approach: use cute.copy with the register and SMEM tensors + # The P SMEM is partitioned by pv_thr, softmax threads fill their portion + # For now: fill rP_bf16_reg from tTMEM_LOADrS (FP32→BF16 conversion) + for j in cutlass.range(cute.size(rP_bf16_reg), vectorize=True): + # TODO: proper element mapping from QK→PV partition + rP_bf16_reg[j] = Float32(0.0) + cute.copy(rP_bf16_reg, tCrP_smem) + cute.arch.fence_proxy("async.shared", space="cta") si_handle.release() softmax_done_bar.arrive()