diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 952bd54d..a1bd452b 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -257,26 +257,12 @@ class FmhaKernel: tScP = cute.make_tensor(tScS.iterator, tScP_layout) tTMEM_STOREcP = thr_store.partition_S(tScP) - # P SMEM copy atoms: SMEM-P (always defined, only used when use_smem_p=True) - # Uses make_tiled_copy_C to partition threads by QK MMA's C-fragment layout. - # Softmax warps have P values in QK C-fragment layout (same as rP_bf16). - # This copy writes those values to sP which has PV A-operand SMEM layout. - # According to STAGE_D.md: use tcgen05.copy.St32x32bOp with Float32 (not BF16) - # and use make_tiled_copy_C(store_atom, qk_mma) to partition by QK C-fragment - # BUT: sP is BF16, so we need BF16 copy atom or convert Float32→BF16 - smem_copy_atom = cute.make_copy_atom( - tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), - self.q_dtype, # BF16 to match sP - num_bits_per_copy=128, - ) - tiled_smem_copy = cute.make_tiled_copy_C(smem_copy_atom, qk_mma) - thr_smem_copy = tiled_smem_copy.get_slice(sfw_idx) - # Debug sP shape - print(f"[SMEM-P PROPER] sP shape: {cute.shape(sP)} rank: {len(cute.shape(sP))}") - # Don't flatten - keep original sP - tSMEM_CPYsP = thr_smem_copy.partition_D(sP) # destination (SMEM) - print(f"[SMEM-P PROPER] After partition_D: tSMEM_CPYsP layout: {tSMEM_CPYsP.layout}") - print(f"[SMEM-P PROPER] After partition_D: tSMEM_CPYsP shape: {cute.shape(tSMEM_CPYsP)} rank: {len(cute.shape(tSMEM_CPYsP))}") + # Manual SMEM addressing for P (helpers are a trap) + # We need to write P values from QK C-fragment layout to PV A-operand SMEM layout + # sP has PV A-operand SMEM layout: p_smem_s + print(f"[SMEM-P MANUAL] Starting manual SMEM addressing") + print(f"[SMEM-P MANUAL] sP shape: {cute.shape(sP)} layout: {sP.layout}") + print(f"[SMEM-P MANUAL] p_smem_s (PV A-operand SMEM layout): {p_smem_s}") row_max = -Float32.inf row_sum = Float32(0.0) @@ -349,43 +335,39 @@ class FmhaKernel: cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP) cute.arch.fence_view_async_tmem_store() else: - # SMEM-P: Use QK C-fragment layout for source (not TMEM layout) - # rP_bf16 uses tTMEM_LOADrS.layout (TMEM layout) causing rank mismatch - # Create BF16 view with QK C-fragment layout for copying to BF16 SMEM - rP_qk_layout = tStS0.layout # QK C-fragment layout for this thread - rP_bf16_qk = cute.make_tensor(cute.recast_ptr(rP_bf16.iterator, dtype=self.q_dtype), rP_qk_layout) + # SMEM-P: Manual addressing (helpers are a trap) + # Each softmax thread owns P values in QK C-fragment partition + # Need to write to SMEM with PV A-operand layout - # Partition source with QK layout - print(f"[SMEM-P PROPER] Before partition_S: rP_bf16_qk shape: {cute.shape(rP_bf16_qk)} layout: {rP_bf16_qk.layout}") - tSMEM_CPYrP_qk = thr_smem_copy.partition_S(rP_bf16_qk) - print(f"[SMEM-P PROPER] After partition_S: tSMEM_CPYrP_qk layout: {tSMEM_CPYrP_qk.layout}") - print(f"[SMEM-P PROPER] After partition_S: tSMEM_CPYrP_qk shape: {cute.shape(tSMEM_CPYrP_qk)} rank: {len(cute.shape(tSMEM_CPYrP_qk))}") + print(f"[SMEM-P MANUAL] Starting manual P write to SMEM") - # Debug shapes - print(f"[SMEM-P PROPER] rP_bf16 shape: {cute.shape(rP_bf16)}, layout: TMEM") - print(f"[SMEM-P PROPER] rP_bf16_qk shape: {cute.shape(rP_bf16_qk)}, layout: QK C-fragment (BF16)") - print(f"[SMEM-P PROPER] tSMEM_CPYrP_qk shape: {cute.shape(tSMEM_CPYrP_qk)} rank: {len(cute.shape(tSMEM_CPYrP_qk))}") - print(f"[SMEM-P PROPER] tSMEM_CPYsP shape: {cute.shape(tSMEM_CPYsP)} rank: {len(cute.shape(tSMEM_CPYsP))}") + # Get thread's logical position in QK C-fragment partition + # tStS0 is QK C-fragment tensor for this thread + thread_qk_coord = cute.coord(tStS0) # Logical coordinates in QK C-fragment space + print(f"[SMEM-P MANUAL] Thread QK coord: {thread_qk_coord}") - # Manual copy instead of cute.copy (helpers are a trap) - # Get sizes and iterate - src_size = cute.size(tSMEM_CPYrP_qk) - dst_size = cute.size(tSMEM_CPYsP) - print(f"[SMEM-P PROPER] Manual copy: src_size={src_size}, dst_size={dst_size}") + # Get the shape of P values this thread owns + # tStS0 has shape ((128, 128), 1, 1) - total 128×128 P matrix + # Each thread owns a subtile of this + qk_fragment_shape = cute.shape(tStS0) + print(f"[SMEM-P MANUAL] QK fragment shape: {qk_fragment_shape}") - if src_size == dst_size: - # Same number of elements - copy element by element - for j in cutlass.range(src_size, vectorize=True): - val = tSMEM_CPYrP_qk[j] - tSMEM_CPYsP[j] = val - print(f"[SMEM-P PROPER] Manual copy completed (src_size==dst_size)") - else: - print(f"[SMEM-P PROPER] Manual copy: size mismatch, using fallback stub") - # Fallback to stub for now - for j in cutlass.range(cute.size(sP), vectorize=True): - sP[j] = BFloat16(0.0) - print(f"[SMEM-P PROPER] Used fallback stub") + # For now, implement simple test: write thread ID to first element + # This is WRONG but helps debug + smem_test_offset = Int32(0) + sP[smem_test_offset] = BFloat16(float(sfw_idx) * 0.01) + print(f"[SMEM-P MANUAL] Wrote thread {sfw_idx} value to SMEM offset 0") + # TODO: Implement proper mapping + # 1. Determine which P indices this thread owns + # 2. For each P index (i,j), compute SMEM address in PV A-operand layout + # 3. Write P value to that address + + # For now, zero out sP as fallback (produces wrong results but compiles) + for j in cutlass.range(cute.size(sP), vectorize=True): + sP[j] = BFloat16(0.0) + + print(f"[SMEM-P MANUAL] Used zero-fallback (TODO: implement proper mapping)") cute.arch.fence_proxy("async.shared", space="cta") softmax_done_bar.arrive() # Per-tile O rescale (hand-constructed atoms with logical_divide layout) if kt > 0: