diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 565f66a8..2548e3b1 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -113,29 +113,19 @@ class FmhaKernel: tma_v,mV = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,pv_mma.thr_id),v_fmha,v_s,self.pv_mma_tiler,pv_mma,self.cluster_layout_vmnk.shape) epi_s = cute.select(self.c_smem_s,mode=[0,1]) tma_c,mC = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(),c,epi_s,self.epi_tile) - - # SMEM-P: TMA for P (GMEM→SMEM). We reuse Q's GMEM buffer as gP. - if const_expr(self.use_smem_p): - p_s = cute.slice_(self.p_smem_s,(None,None,None,0)) - tma_p,gP_tma = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileG2SOp(), mQ, p_s, self.qk_mma_tiler) - else: - tma_p = tma_q # dummy - gP_tma = mQ, dead code # Always create a valid mLSE tensor for the kernel. # CuTeDSL doesn't support None parameters in @cute.kernel. # For normalize=True, mLSE is unused (dead-code-eliminated by compiler). if const_expr(lse is None): lse = cute.make_tensor(c.iterator, cute.make_layout((1,), stride=(0,))) - self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,tma_p,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile,lse).launch(grid=(1,1,1),block=[self.threads_per_cta,1,1],stream=stream) + self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile,lse).launch(grid=(1,1,1),block=[self.threads_per_cta,1,1],stream=stream) @cute.kernel - def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, tma_p, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, p_smem_s, c_smem_s, epi_tile, mLSE): + def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, p_smem_s, c_smem_s, epi_tile, mLSE): warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) tidx,_,_ = cute.arch.thread_idx() if warp_idx == self.tma_warp_id: cpasync.prefetch_descriptor(tma_q); cpasync.prefetch_descriptor(tma_k); cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c) - if const_expr(self.use_smem_p): - cpasync.prefetch_descriptor(tma_p) @cute.struct class SS: @@ -235,12 +225,6 @@ class FmhaKernel: cute.arch.fence_view_async_tmem_store() sh.commit() softmax_done_bar.arrive_and_wait() - # SMEM-P: TMA load gP → sP after softmax writes gP (reusing Q's GMEM) - if const_expr(self.use_smem_p): - tPgP, tPsP = cpasync.tma_partition(tma_p, 0, cute.nvgpu.OperandMajorMode.M, cute.group_modes(sP,0,3), cute.group_modes(mQ,0,3)) - cute.copy(tma_p, tPsP[(None,0,None,0)], tPgP[(None,0,None,0)], tma_bar_ptr=st.s_bar.data_ptr()) - cpasync.commit_group() - cpasync.wait_group(0) pv_mma.set(tcgen05.Field.ACCUMULATE, kt != 0) if not self.use_smem_p: # TMEM-P: PV reads P from TMEM @@ -304,8 +288,6 @@ class FmhaKernel: # sP_addr = 64*m + (k%16) + 16*((k//16)%4) + 8192*(k//64) # This is representable as a CuTe layout: (128, (16, 4, 2)) -> (64, (1, 16, 8192)) _sP_nostage = sP[(None, None, None, 0)] # remove stage dim - print(f"SMEM-P DEBUG: sP_nostage shape={cute.shape(_sP_nostage)}, sP outer={cute.shape(sP)}") - print(f"SMEM-P DEBUG: pv_n_tile={self.pv_n_tile}, p_cols_fp32={p_cols_fp32}") row_max = -Float32.inf row_sum = Float32(0.0) @@ -335,12 +317,6 @@ class FmhaKernel: tTMEM_STOREtO = thr_tmem_store_o.partition_D(tOtO_i) n_corr_tiles = self.pv_n_tile // corr_tile_size - # tTMrO register tensor (defined unconditionally for CuTeDSL scoping). - # Used for O rescale (kt > 0) and O normalization (after loop). - tTMrO = cute.make_rmem_tensor( - (tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype - ) - for kt in range(self.n_kv_tiles): si_handle = s_cons.wait_and_advance() @@ -384,20 +360,25 @@ class FmhaKernel: cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP) cute.arch.fence_view_async_tmem_store() else: - # SMEM-P: Write P to mQ's GMEM (reuse Q buffer after Q loaded to SMEM), - # then TMA loads gP→sP. rP_bf16 and Q's C-fragment partition both - # derived from QK MMA, so element-wise copy matches thread→value mapping. - gP_tile = cute.local_tile(mQ, (128, self.s_k), (0, 0)) - tCgP = qk_thr.partition_C(gP_tile) - # Element-wise copy: rP_bf16 → tCgP (both 128 values per thread) - rP_flat = cute.make_tensor(rP_bf16.iterator, cute.coalesce(rP_bf16.layout)) - gP_flat = cute.make_tensor(tCgP.iterator, cute.coalesce(tCgP.layout)) - for idx in cutlass.range(cute.size(rP_flat), vectorize=True): - gP_flat[idx] = rP_flat[idx] - # Ensure gP writes are visible before TMA reads - cute.arch.fence_proxy("async", space="cta") + # SMEM-P: write P to sP using coordinate-indexed store. + # Uses tTMEM_LOADcS identity tensor to get (m, k) coordinates. + # DEBUG: Write a known pattern to sP to verify the coordinate mapping. + # Pattern: sP[m, k] = (m + k) % 256 as BF16 (unique per position) + for j0 in range(32): + for j1 in range(4): + coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0] + m_coord = coord[0] + k_coord = coord[1] + k0 = k_coord % 16 + k1 = (k_coord // 16) % 4 + k2 = k_coord // 64 + # Debug: write (m + k) mod 256 instead of actual P value + _sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[(j0, 0), j1, 0, 0] cute.arch.fence_proxy("async.shared", space="cta") if kt > 0: + tTMrO = cute.make_rmem_tensor( + (tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype + ) for i in range(n_corr_tiles): tTMrO_i_ = tTMrO[None, i] tTMrO_i_layout = cute.composition( @@ -425,20 +406,39 @@ class FmhaKernel: final_o_bar.arrive_and_wait() # ============================================================ - # EPILOGUE: TMA store O to GMEM + compute LSE + # EPILOGUE: Normalize O + TMA store to GMEM # ============================================================ - # The raw un-normalized O in TMEM is perfect (cos 0.999998). - # TMEM round-trip normalization with hand-constructed atoms causes - # severe data corruption (53% error) due to layout mismatch with - # epilogue_tma_store's paired-atom addressing. - # Solution: always write raw O via epilogue_tma_store, compute LSE, - # and let the caller normalize externally using LSE. - # This is the D5a path — production-quality with zero precision loss. - # The TMEM round-trip normalization (normalize=True) is tracked as D1.5. + # Step 1: Normalize O in TMEM via round-trip (3% error from hand-constructed + # atoms — D1.5 tracks the paired-atom fix). + # Step 2: Use CUTLASS epilogue_tma_store for TMEM→SMEM→GMEM write. # ============================================================ - # TMA store via CUTLASS epilogue_tma_store (reads raw O from TMEM) - tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout) + # D5a: When normalize=False, skip 1/row_sum (emit un-normalized O + LSE). + if const_expr(self.normalize): + inv_row_sum = Float32(1.0) / row_sum + # Normalize O: TMEM round-trip O *= inv_row_sum + for i in range(n_corr_tiles): + tTMrO_i_ = tTMrO[None, i] + tTMrO_i_layout = cute.composition( + tTMrO_i_.layout, cute.make_layout(tTMrO.shape[0]) + ) + tTMrO_i = cute.make_tensor(tTMrO_i_.iterator, tTMrO_i_layout) + tTMEM_LOADtO_i = cute.make_tensor( + tTMEM_LOADtO.iterator + i * corr_tile_size, + tTMEM_LOADtO.layout, + ) + tTMEM_STOREtO_i = cute.make_tensor( + tTMEM_STOREtO.iterator + i * corr_tile_size, + tTMEM_STOREtO.layout, + ) + cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i) + for k in cutlass.range(cute.size(tTMrO_i), vectorize=True): + tTMrO_i[k] = tTMrO_i[k] * inv_row_sum + cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i) + cute.arch.fence_view_async_tmem_store() + + # TMA store via CUTLASS epilogue_tma_store + tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tOtO.layout) c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)) c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp) acc_cons_st = pipeline.make_pipeline_state( @@ -451,16 +451,17 @@ class FmhaKernel: ) c_pipe.producer_tail() - # Compute LSE: lse = ln(row_sum) + row_max * ln(2) - # Always compute LSE (needed for external normalization). + # D5a: Write LSE (log-softmax) when normalize=False + # lse = ln(row_sum) + row_max * ln(2) # row_max is in scale_log2 domain, multiply by ln(2) to convert. - _row_max_safe = row_max - if row_max == -cutlass.Float32.inf: - _row_max_safe = Float32(0.0) - if sfw_idx == 0: - _ln2 = Float32(0.6931471805599453) # ln(2) - lse_val = cute.math.log(row_sum, fastmath=True) + _row_max_safe * _ln2 - mLSE[0] = lse_val + if const_expr(not self.normalize): + _row_max_safe = row_max + if row_max == -cutlass.Float32.inf: + _row_max_safe = Float32(0.0) + if sfw_idx == 0: + _ln2 = Float32(0.6931471805599453) # ln(2) + lse_val = cute.math.log(row_sum, fastmath=True) + _row_max_safe * _ln2 + mLSE[0] = lse_val tmem.relinquish_alloc_permit() tmem.free(tmem_ptr)