WIP: make_tiled_copy_C for P→SMEM

This commit is contained in:
2026-05-23 03:56:56 +00:00
parent cf080ccf00
commit 5af491cd73

View File

@@ -200,11 +200,12 @@ class FmhaKernel:
tScS = qk_thr.partition_C(cS)
tTMEM_LOADcS = thr_load.partition_D(tScS)
# P → SMEM copy setup
p_copy_atom = cute.make_copy_atom(cpasync.CopyOp(), self.q_dtype)
tiled_p_copy = cute.make_tiled_copy(p_copy_atom, tCrP_smem.layout, tCrP_reg.layout, tidx)
tPS_sP = tiled_p_copy.get_slice(sfw_idx).partition_D(tCrP_smem)
tPS_rP = tiled_p_copy.get_slice(sfw_idx).partition_S(tCrP_reg)
# P → SMEM: use make_tiled_copy_C for register→SMEM (standard epilogue pattern)
# The P values are the A operand of PV, written to SMEM so the MMA can read them
p_s = cute.slice_(p_smem_s,(None,None,None,0))
tCrP_smem = pv_thr.partition_A(sP) # PV thread → SMEM partition for P (A operand)
tCrP_reg = pv_mma.make_fragment_A(sP) # register fragment matching SMEM layout
tiled_p_copy = cute.make_tiled_copy_C(pv_mma, tCrP_smem, p_s, 1)
# Online softmax state
row_max = -Float32.inf
@@ -240,11 +241,11 @@ class FmhaKernel:
tTMEM_LOADrS_frg[k, j] = cute.math.exp2(tTMEM_LOADrS_frg[k, j], fastmath=True)
row_sum = row_sum + tTMEM_LOADrS_frg[k, j]
# Write P to SMEM using PV A-operand thread partition
# Write P to SMEM using PV A-operand partition
# TODO: proper element mapping from QK→PV partition
for j in cutlass.range(cute.size(tCrP_reg), vectorize=True):
tCrP_reg[j] = BFloat16(0.0)
cute.copy(tiled_p_copy, tPS_rP, tPS_sP)
cute.copy(tiled_p_copy, tCrP_reg, tCrP_smem)
cute.arch.fence_proxy("async.shared", space="cta")
si_handle.release()