diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 14bad065..95b56f7c 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -15,7 +15,7 @@ import math class FmhaKernel: - def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None): + def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None): self.head_dim = head_dim self.s_k = s_k self.n_kv_tiles = s_k // 128 @@ -31,6 +31,7 @@ class FmhaKernel: self.kv_stage = 2; self.q_stage = 1; self.num_c_stage = 2 self.scale_softmax = scale_softmax if scale_softmax is not None else 1.0 / math.sqrt(self.head_dim) self.scale_softmax_log2 = self.scale_softmax * math.log2(math.e) + def _setup(self, qk_mma, pv_mma): qk_ik = cute.size(qk_mma.shape_mnk, mode=[2]) self.qk_mma_tiler = (128, 128, qk_ik * 4) @@ -93,9 +94,6 @@ class FmhaKernel: (self.pv_n_tile, self.s_k, 1), stride=(1, self.pv_n_tile, self.pv_n_tile * self.s_k), ), - ) - stride=(1, self.head_dim, self.head_dim * self.s_k), - ), ) self.v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode() self.c_layout = LayoutEnum.from_tensor(c) @@ -344,11 +342,14 @@ class FmhaKernel: cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP) cute.arch.fence_view_async_tmem_store() else: - # SMEM-P: TODO — write P to SMEM via make_tiled_copy_C(store_atom, qk_mma) - # For now, zero sP as stub. PV will produce garbage with SMEM-P path. - for j in cutlass.range(cute.size(sP), vectorize=True): - sP[j] = BFloat16(0.0) + # SMEM-P: write P to SMEM via tiled_smem_copy + # rP_bf16 contains P values in QK C-fragment layout (BF16) + # Flatten to 2D for copy operation + rP_bf16_2d = cute.group_modes(rP_bf16, 0, 2) + tSMEM_CPYrP = thr_smem_copy.partition_S(rP_bf16_2d) + cute.copy(tiled_smem_copy, tSMEM_CPYrP, tSMEM_CPYsP) cute.arch.fence_proxy("async.shared", space="cta") + softmax_done_bar.arrive() # Per-tile O rescale (hand-constructed atoms with logical_divide layout) if kt > 0: