Merge branch 'master' of ssh://sweetapi.com:2222/biondizzle/nvfp4-megamoe-kernel
This commit is contained in:
@@ -366,19 +366,21 @@ class FmhaKernel:
|
||||
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
|
||||
cute.arch.fence_view_async_tmem_store()
|
||||
else:
|
||||
# SMEM-P: write P to sP using coordinate-indexed store.
|
||||
# Uses tTMEM_LOADcS identity tensor to get (m, k) coordinates
|
||||
# and maps them to sP's swizzled layout.
|
||||
for j0 in range(32):
|
||||
for j1 in range(4):
|
||||
coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0]
|
||||
m_coord = coord[0]
|
||||
k_coord = coord[1]
|
||||
k0 = k_coord % 16
|
||||
k1 = (k_coord // 16) % 4
|
||||
k2 = k_coord // 64
|
||||
_sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[(j0, 0), j1, 0, 0]
|
||||
cute.arch.fence_proxy("async.shared", space="cta")
|
||||
# SMEM-P: write P to sP using TiledCopy derived from QK MMA.
|
||||
#
|
||||
# make_tiled_copy_C with qk_mma gives threads partitioned by the
|
||||
# QK C-fragment (same as TMEM load). Source: rP_bf16 (registers).
|
||||
# Destination: sP (PV A-operand SMEM layout).
|
||||
_smem_p_store_atom = cute.make_copy_atom(
|
||||
cute.nvgpu.CopyUniversalOp(),
|
||||
self.q_dtype,
|
||||
num_bits_per_copy=16,
|
||||
)
|
||||
_tiled_smem_p = cute.make_tiled_copy_C(_smem_p_store_atom, qk_mma)
|
||||
_thr_smem_p = _tiled_smem_p.get_slice(sfw_idx)
|
||||
_tRS_rP = _thr_smem_p.partition_S(rP_bf16)
|
||||
_tRS_sP = _thr_smem_p.partition_D(_sP_nostage)
|
||||
cute.copy(_tiled_smem_p, _tRS_rP, _tRS_sP)
|
||||
cute.arch.fence_proxy("async.shared", space="cta")
|
||||
if kt > 0:
|
||||
for i in range(n_corr_tiles):
|
||||
|
||||
Reference in New Issue
Block a user