diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index ad6a2828..531d4ccf 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -164,6 +164,13 @@ class FmhaKernel: qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2]) tStS = qk_thr.make_fragment_C(qk_as) tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout) + + # Create coordinate tensor for QK C-fragment layout + # Each element maps to its logical coordinate ((m,n),0,0) + if self.use_smem_p: + cP_qk = cute.make_identity_tensor(tStS0.shape) + print(f"[SMEM-P CUTLASS] Created cP_qk shape: {cute.shape(cP_qk)}") + pv_as = pv_thr.partition_shape_C(self.pv_mma_tiler[:2]) tOtO = pv_thr.make_fragment_C(pv_as) tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout) @@ -309,6 +316,11 @@ class FmhaKernel: frg_cnt = 4 frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile)) + # Coordinate fragments for SMEM-P mapping + if self.use_smem_p: + tTMEM_LOADcS_frg = cute.logical_divide(tTMEM_LOADcS, cute.make_layout(frg_tile)) + print(f"[SMEM-P CUTLASS] Created tTMEM_LOADcS_frg shape: {cute.shape(tTMEM_LOADcS_frg)}") + for j in range(frg_cnt): for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])): row_max = cute.arch.fmax(row_max, tTMEM_LOADrS_frg[k, j] * scale_log2) @@ -332,6 +344,25 @@ class FmhaKernel: for k in range(cute.size(tTMEM_LOADrS_frg, mode=[0])): tTMEM_LOADrS_frg[k, j] = tTMEM_LOADrS_frg[k, j] * scale_log2 + minus_row_max tTMEM_LOADrS_frg[k, j] = cute.math.exp2(tTMEM_LOADrS_frg[k, j], fastmath=True) + + # If using SMEM-P, write P value directly to SMEM + if self.use_smem_p: + # Get QK coordinate for this position + qk_coord = tTMEM_LOADcS_frg[k, j] + mn = qk_coord[0] + m = mn[0] + n = mn[1] + + # Map to PV SMEM coordinate + n0 = n % 16 + n1 = (n // 16) % 4 + n2 = n // 64 + pv_coord = ((m, n0), 0, (n1, n2), 0) + + # Convert Float32 → BF16 and write to SMEM + p_val_bf16 = tTMEM_LOADrS_frg[k, j].to(self.q_dtype) + sP[pv_coord] = p_val_bf16 + row_sum = row_sum + tTMEM_LOADrS_frg[k, j] s_vec = tTMEM_LOADrS_frg[None, j].load() rP_bf16_frg[None, j].store(s_vec.to(self.q_dtype)) @@ -341,51 +372,13 @@ class FmhaKernel: cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP) cute.arch.fence_view_async_tmem_store() else: - # SMEM-P: Manual addressing with CUTLASS LLM pattern - print(f"[SMEM-P CUTLASS] Starting manual P write to SMEM") - - # Get thread index for coordinate computations - tidx, _, _ = cute.arch.thread_idx() - warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) - - # Function to map QK coordinate to PV SMEM coordinate - # QK: ((m, n), 0, 0) → PV: ((m, n % 16), 0, ((n // 16) % 4, n // 64), 0) - def qk_to_pv_coord(m, n): - n0 = n % 16 - n1 = (n // 16) % 4 - n2 = n // 64 - return ((m, n0), 0, (n1, n2), 0) - - # Each thread handles 32×1 tile × 4 fragments = 128 P values - # We need to map each of these 128 values to SMEM - - # For testing: write a simple pattern to verify mapping works - # Each thread writes to different coordinate for testing - # Use thread-relative simple coordinates - thread_offset = tidx % 16 # 0-15 - test_m = thread_offset - test_n = thread_offset - test_coord = qk_to_pv_coord(test_m, test_n) - - # Write constant test value (0.5) - test_val = BFloat16(0.5) - sP[test_coord] = test_val - print(f"[SMEM-P CUTLASS] Thread wrote test to coord {test_coord}") - - # TODO: Implement full 128-value mapping - # Need to: - # 1. Create coordinate tensor with make_identity_tensor(tStS0.shape) - # 2. Partition it the same way as rP_bf16 - # 3. For each of the 128 P values, get its QK coordinate - # 4. Map to PV coordinate using qk_to_pv_coord - # 5. Write to sP[dst_coord] - - # For now, zero rest of sP (except our test value) - # This is WRONG but allows compilation - print(f"[SMEM-P CUTLASS] WARNING: Only wrote test value, rest zeroed (incomplete)") - + # SMEM-P: Already wrote P values to SMEM in softmax loop + # Just need fence and barrier + print(f"[SMEM-P CUTLASS] P values already written to SMEM, proceeding to fence") cute.arch.fence_proxy("async.shared", space="cta") - softmax_done_bar.arrive() # Per-tile O rescale (hand-constructed atoms with logical_divide layout) + + # Barrier for both TMEM-P and SMEM-P paths + softmax_done_bar.arrive() # Per-tile O rescale (hand-constructed atoms with logical_divide layout) if kt > 0: tTMrO = cute.make_rmem_tensor( (tTMEM_LOADcO.shape, 128 // corr_tile_size), self.acc_dtype