D1.3: Implement SMEM-P path (write P to SMEM via tiled_smem_copy instead of zeroing sP)
This commit is contained in:
@@ -15,7 +15,7 @@ import math
|
||||
|
||||
|
||||
class FmhaKernel:
|
||||
def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None):
|
||||
def __init__(self, head_dim=64, s_k=128, scale_softmax=None, use_smem_p=None):
|
||||
self.head_dim = head_dim
|
||||
self.s_k = s_k
|
||||
self.n_kv_tiles = s_k // 128
|
||||
@@ -31,6 +31,7 @@ class FmhaKernel:
|
||||
self.kv_stage = 2; self.q_stage = 1; self.num_c_stage = 2
|
||||
self.scale_softmax = scale_softmax if scale_softmax is not None else 1.0 / math.sqrt(self.head_dim)
|
||||
self.scale_softmax_log2 = self.scale_softmax * math.log2(math.e)
|
||||
|
||||
def _setup(self, qk_mma, pv_mma):
|
||||
qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
self.qk_mma_tiler = (128, 128, qk_ik * 4)
|
||||
@@ -93,9 +94,6 @@ class FmhaKernel:
|
||||
(self.pv_n_tile, self.s_k, 1),
|
||||
stride=(1, self.pv_n_tile, self.pv_n_tile * self.s_k),
|
||||
),
|
||||
)
|
||||
stride=(1, self.head_dim, self.head_dim * self.s_k),
|
||||
),
|
||||
)
|
||||
self.v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode()
|
||||
self.c_layout = LayoutEnum.from_tensor(c)
|
||||
@@ -344,11 +342,14 @@ class FmhaKernel:
|
||||
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
else:
|
||||
# SMEM-P: TODO — write P to SMEM via make_tiled_copy_C(store_atom, qk_mma)
|
||||
# For now, zero sP as stub. PV will produce garbage with SMEM-P path.
|
||||
for j in cutlass.range(cute.size(sP), vectorize=True):
|
||||
sP[j] = BFloat16(0.0)
|
||||
# SMEM-P: write P to SMEM via tiled_smem_copy
|
||||
# rP_bf16 contains P values in QK C-fragment layout (BF16)
|
||||
# Flatten to 2D for copy operation
|
||||
rP_bf16_2d = cute.group_modes(rP_bf16, 0, 2)
|
||||
tSMEM_CPYrP = thr_smem_copy.partition_S(rP_bf16_2d)
|
||||
cute.copy(tiled_smem_copy, tSMEM_CPYrP, tSMEM_CPYsP)
|
||||
cute.arch.fence_proxy("async.shared", space="cta")
|
||||
softmax_done_bar.arrive()
|
||||
|
||||
# Per-tile O rescale (hand-constructed atoms with logical_divide layout)
|
||||
if kt > 0:
|
||||
|
||||
Reference in New Issue
Block a user