Merge branch 'master' of ssh://sweetapi.com:2222/biondizzle/nvfp4-megamoe-kernel

This commit is contained in:
2026-05-24 03:23:22 +00:00

View File

@@ -366,19 +366,21 @@ class FmhaKernel:
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
cute.arch.fence_view_async_tmem_store()
else:
# SMEM-P: write P to sP using coordinate-indexed store.
# Uses tTMEM_LOADcS identity tensor to get (m, k) coordinates
# and maps them to sP's swizzled layout.
for j0 in range(32):
for j1 in range(4):
coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0]
m_coord = coord[0]
k_coord = coord[1]
k0 = k_coord % 16
k1 = (k_coord // 16) % 4
k2 = k_coord // 64
_sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[(j0, 0), j1, 0, 0]
cute.arch.fence_proxy("async.shared", space="cta")
# SMEM-P: write P to sP using TiledCopy derived from QK MMA.
#
# make_tiled_copy_C with qk_mma gives threads partitioned by the
# QK C-fragment (same as TMEM load). Source: rP_bf16 (registers).
# Destination: sP (PV A-operand SMEM layout).
_smem_p_store_atom = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(),
self.q_dtype,
num_bits_per_copy=16,
)
_tiled_smem_p = cute.make_tiled_copy_C(_smem_p_store_atom, qk_mma)
_thr_smem_p = _tiled_smem_p.get_slice(sfw_idx)
_tRS_rP = _thr_smem_p.partition_S(rP_bf16)
_tRS_sP = _thr_smem_p.partition_D(_sP_nostage)
cute.copy(_tiled_smem_p, _tRS_rP, _tRS_sP)
cute.arch.fence_proxy("async.shared", space="cta")
if kt > 0:
for i in range(n_corr_tiles):