diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index c40d9c7d..d31e0f2d 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -87,7 +87,7 @@ class FmhaKernel: cute.size_in_bytes(self.q_dtype, v_s)) * cta @cute.jit - def __call__(self, q, k, v, c, stream): + def __call__(self, q, k, v, c, stream, lse=None): self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype self.a_major = LayoutEnum.from_tensor(q).mma_major_mode() self.b_major = LayoutEnum.from_tensor(k).mma_major_mode() @@ -111,10 +111,10 @@ class FmhaKernel: tma_v,mV = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,pv_mma.thr_id),v_fmha,v_s,self.pv_mma_tiler,pv_mma,self.cluster_layout_vmnk.shape) epi_s = cute.select(self.c_smem_s,mode=[0,1]) tma_c,mC = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(),c,epi_s,self.epi_tile) - self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile).launch(grid=(1,1,1),block=[self.threads_per_cta,1,1],stream=stream) + self._kernel(qk_mma,pv_mma,tma_q,mQ,tma_k,mK,tma_v,mV,tma_c,mC,self.cluster_layout_vmnk,self.q_smem_s,self.k_smem_s,self.v_smem_s,self.p_tmem_s,self.p_smem_s,self.c_smem_s,self.epi_tile,lse).launch(grid=(1,1,1),block=[self.threads_per_cta,1,1],stream=stream) @cute.kernel - def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, p_smem_s, c_smem_s, epi_tile): + def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk, q_smem_s, k_smem_s, v_smem_s, p_tmem_s, p_smem_s, c_smem_s, epi_tile, mLSE): warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) tidx,_,_ = cute.arch.thread_idx() if warp_idx == self.tma_warp_id: @@ -264,19 +264,10 @@ class FmhaKernel: tScP = 
cute.make_tensor(tScS.iterator, tScP_layout) tTMEM_STOREcP = thr_store.partition_S(tScP) - # P SMEM copy atoms: SMEM-P (always defined, only used when use_smem_p=True) - # Uses make_tiled_copy_C to partition threads by QK MMA's C-fragment layout. - # Softmax warps have P values in QK C-fragment layout (same as rP_bf16). - # This copy writes those values to sP which has PV A-operand SMEM layout. - smem_copy_atom = cute.make_copy_atom( - cute.nvgpu.CopyUniversalOp(), - self.q_dtype, - num_bits_per_copy=128, - ) - tiled_smem_copy = cute.make_tiled_copy_C(smem_copy_atom, qk_mma) - thr_smem_copy = tiled_smem_copy.get_slice(sfw_idx) - sP_2d = cute.group_modes(sP, 0, 3) - tSMEM_CPYsP = thr_smem_copy.partition_D(sP_2d) # destination (SMEM) + # P SMEM copy atoms: SMEM-P (TBD) + # make_tiled_copy_C gives rank mismatch (QK C-fragment has 4 modes, + # PV A-operand SMEM has 3 modes). Need proper layout-aware copy. + # For now, SMEM-P path zero-fills sP. TMEM-P (hd<=64) works correctly. row_max = -Float32.inf row_sum = Float32(0.0) @@ -349,34 +340,13 @@ class FmhaKernel: cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP) cute.arch.fence_view_async_tmem_store() else: - # SMEM-P: Use QK C-fragment layout for source (not TMEM layout) - # rP_bf16 uses tTMEM_LOADrS.layout (TMEM layout) causing rank mismatch - # Create view with QK C-fragment layout (tStS0.layout) - rP_qk_layout = tStS0.layout # QK C-fragment layout for this thread - rP_qk = cute.make_tensor(cute.recast_ptr(rP_bf16.iterator, dtype=self.q_dtype), rP_qk_layout) - - # Partition source with QK layout - tSMEM_CPYrP_qk = thr_smem_copy.partition_S(rP_qk) - - # Debug shapes - print(f"[SMEM-P PROPER] rP_bf16 shape: {cute.shape(rP_bf16)}, layout: TMEM") - print(f"[SMEM-P PROPER] rP_qk shape: {cute.shape(rP_qk)}, layout: QK C-fragment") - print(f"[SMEM-P PROPER] tSMEM_CPYrP_qk shape: {cute.shape(tSMEM_CPYrP_qk)} rank: {len(cute.shape(tSMEM_CPYrP_qk))}") - print(f"[SMEM-P PROPER] tSMEM_CPYsP shape: {cute.shape(tSMEM_CPYsP)} 
rank: {len(cute.shape(tSMEM_CPYsP))}") - - # Attempt copy with correct layout - try: - cute.copy(tiled_smem_copy, tSMEM_CPYrP_qk, tSMEM_CPYsP) - print(f"[SMEM-P PROPER] Copy succeeded with QK C-fragment layout") - except Exception as e: - print(f"[SMEM-P PROPER] Copy failed: {e}") - # Fallback to stub for now - for j in cutlass.range(cute.size(sP), vectorize=True): - sP[j] = BFloat16(0.0) - print(f"[SMEM-P PROPER] Used fallback stub") - + # SMEM-P: zero-fill sP stub (proper layout-aware copy TBD) + # make_tiled_copy_C gives rank mismatch (4 vs 3). + # Need proper P register→SMEM copy that respects QK C-fragment layout + # and PV A-operand SMEM layout. For now, TMEM-P (hd<=64) works. + for j in cutlass.range(cute.size(sP), vectorize=True): + sP[j] = self.q_dtype(0) cute.arch.fence_proxy("async.shared", space="cta") - softmax_done_bar.arrive() # Per-tile O rescale (hand-constructed atoms with logical_divide layout) if kt > 0: tTMrO = cute.make_rmem_tensor( (tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype @@ -430,7 +400,9 @@ class FmhaKernel: cute.arch.fence_view_async_tmem_store() # === Final O normalization: O *= 1/row_sum === - inv_row_sum = Float32(1.0) / row_sum + # D5a: When normalize=False, skip normalization (emit un-normalized O + lse) + if const_expr(self.normalize): + inv_row_sum = Float32(1.0) / row_sum tTMrO = cute.make_rmem_tensor( (tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype @@ -450,8 +422,9 @@ class FmhaKernel: ) cute.copy(tiled_tmem_load_o, tTMEM_LOADtO_i, tTMrO_i) - for j in cutlass.range(cute.size(tTMrO_i), vectorize=True): - tTMrO_i[j] = tTMrO_i[j] * inv_row_sum + if const_expr(self.normalize): + for j in cutlass.range(cute.size(tTMrO_i), vectorize=True): + tTMrO_i[j] = tTMrO_i[j] * inv_row_sum cute.copy(tiled_tmem_store_o, tTMrO_i, tTMEM_STOREtO_i) cute.arch.fence_view_async_tmem_store() @@ -470,5 +443,15 @@ class FmhaKernel: ) c_pipe.producer_tail() + # D5a: Write LSE (log-sum-exp) when normalize=False + # lse = 
log(row_sum) + row_max (row_max in scaled domain) + # Only thread 0 of the epilogue warps writes LSE for this tile. + # For M=1 decode: one lse value per query row. + if const_expr(not self.normalize): + if mLSE is not None: + if sfw_idx == 0: + lse_val = cute.math.log(row_sum, fastmath=True) + row_max_safe + mLSE[None, None, 0] = lse_val + tmem.relinquish_alloc_permit() tmem.free(tmem_ptr) diff --git a/tests/unit/test_fmha_v3_stage_d1.py b/tests/unit/test_fmha_v3_stage_d1.py index 985b67fe..404eaf65 100644 --- a/tests/unit/test_fmha_v3_stage_d1.py +++ b/tests/unit/test_fmha_v3_stage_d1.py @@ -98,6 +98,8 @@ def test(): cos64 = test_head_dim(64, 128) # hd=256: single PV tile at MMA instruction max + # NOTE: SMEM-P path is a stub (zero-fill), so hd>64 will FAIL + # until the proper P register→SMEM copy is implemented. print("\n--- HEAD_DIM=256 (single PV tile) ---") cos256 = test_head_dim(256, 128) @@ -105,11 +107,70 @@ def test(): print("\n--- HEAD_DIM=512 (2 PV tiles) ---") cos512 = test_head_dim(512, 128) + # D5a: normalize=False with LSE output + print("\n--- D5a: normalize=False, LSE output (hd=64) ---") + hd = 64; n_kv = 128; m = 128 + torch.manual_seed(42) + q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda') + k = torch.randn(n_kv, hd, 1, dtype=torch.bfloat16, device='cuda') + v = torch.randn(n_kv, hd, dtype=torch.bfloat16, device='cuda') + c_unnorm = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda') + lse_tensor = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda') + + # FP32 reference + qf = q[:, :, 0].float() + kf = k[:, :, 0].float() + scale = 1.0 / math.sqrt(hd) + attn = qf @ kf.T * scale + # Compute reference LSE: log(sum(exp(attn - max))) + max + attn_max = attn.max(dim=-1, keepdim=True)[0] + attn_exp = torch.exp(attn - attn_max) + attn_sum = attn_exp.sum(dim=-1, keepdim=True) + ref_lse = torch.log(attn_sum.squeeze(-1)) + attn_max.squeeze(-1) # (m,) + ref_attn = attn_exp / attn_sum + ref = ref_attn @ v.float() + # Un-normalized 
reference: O_unnorm = sum(P * V) (no 1/row_sum) + ref_unnorm = attn_exp @ v.float() # un-normalized + + kernel = FmhaKernel(head_dim=hd, s_k=n_kv, normalize=False) + pv_n_tile = kernel.pv_n_tile + + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + v_tile = v[:, 0:pv_n_tile].contiguous().unsqueeze(-1) + c_tile = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda') + + mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) + mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k)) + mV = ct.from_dlpack(v_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_tile)) + mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile)) + mLSE = ct.from_dlpack(lse_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tensor)) + + print('Compiling normalize=False kernel...', flush=True) + compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, mLSE) + + compiled(mQ, mK, mV, mC, stream, mLSE) + torch.cuda.synchronize() + + out_unnorm = c_tile[:, :, 0].float() + lse_out = lse_tensor[0, 0, 0].item() + + # Verify un-normalized output matches reference + cos_unnorm = torch.nn.functional.cosine_similarity( + out_unnorm.flatten().unsqueeze(0), ref_unnorm.flatten().unsqueeze(0) + ).item() + # Verify LSE matches reference (first row) + ref_lse_val = ref_lse[0].item() + lse_err = abs(lse_out - ref_lse_val) + print(f' Un-norm O: cos {cos_unnorm:.6f} (should be >= 0.97)') + print(f' LSE: kernel={lse_out:.6f} ref={ref_lse_val:.6f} err={lse_err:.6f}') + # Summary print("\n=== Summary ===") print(f"hd=64, n=128: cos={cos64:.6f} {'PASS' if cos64 >= 0.97 else 'FAIL'}") print(f"hd=256, n=128: cos={cos256:.6f} {'PASS' if cos256 >= 0.97 else 'FAIL'}") print(f"hd=512, n=128: cos={cos512:.6f} {'PASS' if cos512 >= 0.97 else 'FAIL'}") + print(f"D5a unnorm: cos={cos_unnorm:.6f} lse_err={lse_err:.6f}") if __name__ == '__main__':