D1.3: Implement SMEM-P path (write P to SMEM via tiled_smem_copy instead of zeroing sP)

This commit is contained in:
2026-05-23 09:19:35 +00:00
parent 1d1de22775
commit 162bf51d64

View File

@@ -15,7 +15,7 @@ import math
class FmhaKernel:
def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None):
def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None):
self.head_dim = head_dim
self.s_k = s_k
self.n_kv_tiles = s_k // 128
@@ -31,6 +31,7 @@ class FmhaKernel:
self.kv_stage = 2; self.q_stage = 1; self.num_c_stage = 2
self.scale_softmax = scale_softmax if scale_softmax is not None else 1.0 / math.sqrt(self.head_dim)
self.scale_softmax_log2 = self.scale_softmax * math.log2(math.e)
def _setup(self, qk_mma, pv_mma):
qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (128, 128, qk_ik * 4)
@@ -93,9 +94,6 @@ class FmhaKernel:
(self.pv_n_tile, self.s_k, 1),
stride=(1, self.pv_n_tile, self.pv_n_tile * self.s_k),
),
)
stride=(1, self.head_dim, self.head_dim * self.s_k),
),
)
self.v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode()
self.c_layout = LayoutEnum.from_tensor(c)
@@ -344,11 +342,14 @@ class FmhaKernel:
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
cute.arch.fence_view_async_tmem_store()
else:
# SMEM-P: TODO — write P to SMEM via make_tiled_copy_C(store_atom, qk_mma)
# For now, zero sP as stub. PV will produce garbage with SMEM-P path.
for j in cutlass.range(cute.size(sP), vectorize=True):
sP[j] = BFloat16(0.0)
# SMEM-P: write P to SMEM via tiled_smem_copy
# rP_bf16 contains P values in QK C-fragment layout (BF16)
# Flatten to 2D for copy operation
rP_bf16_2d = cute.group_modes(rP_bf16, 0, 2)
tSMEM_CPYrP = thr_smem_copy.partition_S(rP_bf16_2d)
cute.copy(tiled_smem_copy, tSMEM_CPYrP, tSMEM_CPYsP)
cute.arch.fence_proxy("async.shared", space="cta")
softmax_done_bar.arrive()
# Per-tile O rescale (hand-constructed atoms with logical_divide layout)
if kt > 0: