Merge branch 'master' of ssh://sweetapi.com:2222/biondizzle/nvfp4-megamoe-kernel

This commit is contained in:
2026-05-24 02:41:39 +00:00

View File

@@ -366,18 +366,40 @@ class FmhaKernel:
cute.copy(tiled_tmem_store, rP_words, tTMEM_STOREtP)
cute.arch.fence_view_async_tmem_store()
else:
# SMEM-P: write P to sP using coordinate-indexed store.
# Uses tTMEM_LOADcS identity tensor to get (m, k) coordinates.
for j0 in range(32):
for j1 in range(4):
coord = tTMEM_LOADcS[(j0, 0), j1, 0, 0]
m_coord = coord[0]
k_coord = coord[1]
k0 = k_coord % 16
k1 = (k_coord // 16) % 4
k2 = k_coord // 64
_sP_nostage[(m_coord, k0), 0, (k1, k2)] = rP_bf16[(j0, 0), j1, 0, 0]
cute.arch.fence_proxy("async.shared", space="cta")
# SMEM-P: write P to sP using make_cotiled_copy.
# Build atom_layout_tv: (tid, vid) -> sP_addr.
# Each softmax thread (128 total) owns one row of the 128×128 P matrix.
# Within a row, 128 values decompose as (k0, k1, k2) = (16, 4, 2)
# matching sP's subtiled layout strides (1, 16, 8192).
# Thread stride = 64 (one row in sP's layout).
#
# atom_layout_tv(tid, k0, k1, k2) = 64*tid + k0 + 16*k1 + 8192*k2
# sP_addr(m, k0, k1, k2) = 64*m + k0 + 16*k1 + 8192*k2
# These are the SAME function — tid maps to m naturally.
_smem_p_atom = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(),
self.q_dtype,
num_bits_per_copy=16,
)
_smem_p_atom_layout_tv = cute.make_layout(
(128, (16, 4, 2)),
stride=(64, (1, 16, 8192)),
)
# Build sP data layout (coalesced, no swizzle, no stage dim)
# sP_outer has shape ((128,16),1,(4,2),1) with strides ((64,1),0,(16,8192),0)
# Coalesced: ((128,16,4,2),(64,1,16,8192))
_smem_p_data_layout = cute.make_layout(
(128, 16, 4, 2),
stride=(64, 1, 16, 8192),
)
_tiled_smem_p = cute.make_cotiled_copy(
_smem_p_atom, _smem_p_atom_layout_tv, _smem_p_data_layout,
)
_thr_smem_p = _tiled_smem_p.get_slice(sfw_idx)
_tRS_sP = _thr_smem_p.partition_D(_sP_nostage)
_tRS_rP = _tiled_smem_p.retile(rP_bf16)
cute.copy(_tiled_smem_p, _tRS_rP, _tRS_sP)
cute.arch.fence("shared", space="cta")
if kt > 0:
for i in range(n_corr_tiles):
tTMrO_i_ = tTMrO[None, i]