# File: nvfp4-megamoe-kernel/tests/unit/test_d1_3_layout_diag.py (Python; listed as 198 lines, 10 KiB)

"""
D1.3 SMEM-P diagnostic: print ALL relevant shapes and layouts.
Runs layout analysis inside a JIT-compiled function to get MLIR context.
"""
import torch, math, sys
import cutlass, cutlass.cute as cute, cutlass.utils as utils
from cutlass.cute.nvgpu import tcgen05, cpasync
from cutlass import Float32, BFloat16, Int32, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cutlass.torch as ct
def main():
    """Print MMA / TMEM / SMEM layout diagnostics for head dims 64 and 256.

    For each head dim this builds random bfloat16 tensors on the CUDA device,
    defines a throw-away ``@cute.kernel`` function whose body only constructs
    tiled MMAs, fragments and copy partitions and prints their shapes and
    layouts, then compiles and invokes it via ``cute.compile`` so the layout
    queries execute inside the DSL's compilation context.

    Nothing is returned; all results go to stdout. Requires a CUDA device.
    Individual probes that may fail are wrapped in ``try``/``except`` so one
    failure is reported and the remaining diagnostics still print.
    """
    # Project-local import kept at function scope; it is loop-invariant, so
    # it is done once instead of once per head dim.
    from dsv4.kernels.attention.fmha import FmhaKernel

    for hd in [64, 256]:
        print(f"\n{'='*60}")
        print(f" HEAD_DIM={hd}")
        print(f"{'='*60}")

        m = 128
        s_k = 128
        use_smem_p = hd > 64
        pv_n_tile = min(hd, 256)

        q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
        k = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda')
        v = torch.randn(s_k, pv_n_tile, dtype=torch.bfloat16, device='cuda')
        c = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')

        # NOTE(review): `kern` is not used below -- the MMAs are rebuilt by
        # hand inside _diag_kernel instead of going through FmhaKernel's own
        # setup path. The construction is kept in case the constructor
        # validates (head_dim, s_k); confirm whether it can be dropped.
        kern = FmhaKernel(head_dim=hd, s_k=s_k)

        a_major = LayoutEnum.from_tensor(ct.from_dlpack(q)).mma_major_mode()
        b_major = LayoutEnum.from_tensor(ct.from_dlpack(k)).mma_major_mode()

        # Minimal analysis kernel that just prints shapes. It is defined
        # inside the loop because it closes over the per-head-dim values
        # (use_smem_p, pv_n_tile, a_major, b_major, c); its tensor arguments
        # are not read.
        @cute.kernel
        def _diag_kernel(_q, _k, _v, _c):
            # Build QK MMA
            qk_mma = utils.sm100.make_trivial_tiled_mma(
                BFloat16, BFloat16, a_major, b_major, Float32,
                tcgen05.CtaGroup.ONE, (128, 128), tcgen05.OperandSource.SMEM,
            )

            # Build PV MMA; the A operand comes from SMEM or TMEM depending
            # on use_smem_p.
            pv_a_major = a_major if use_smem_p else cute.nvgpu.OperandMajorMode.K
            pv_source = tcgen05.OperandSource.SMEM if use_smem_p else tcgen05.OperandSource.TMEM
            c_major = LayoutEnum.from_tensor(ct.from_dlpack(c)).mma_major_mode()
            pv_mma = utils.sm100.make_trivial_tiled_mma(
                BFloat16, BFloat16, pv_a_major, c_major, Float32,
                tcgen05.CtaGroup.ONE, (128, pv_n_tile), pv_source,
            )

            qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
            qk_mma_tiler = (128, 128, qk_ik * 4)
            pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
            pv_mma_tiler = (128, pv_n_tile, pv_ik * (128 // pv_ik))

            print(f" QK MMA shape_mnk: {qk_mma.shape_mnk}")
            print(f" PV MMA shape_mnk: {pv_mma.shape_mnk}")
            print(f" QK tiler: {qk_mma_tiler}")
            print(f" PV tiler: {pv_mma_tiler}")
            print(f" use_smem_p: {use_smem_p}")

            # QK C-fragment
            qk_thr = qk_mma.get_slice(0)
            qk_as = qk_thr.partition_shape_C(qk_mma_tiler[:2])
            tStS = qk_thr.make_fragment_C(qk_as)
            print(f" tStS shape: {cute.shape(tStS)} layout: {tStS.layout}")

            # PV C-fragment
            pv_thr = pv_mma.get_slice(0)
            pv_as = pv_thr.partition_shape_C(pv_mma_tiler[:2])
            tOtO = pv_thr.make_fragment_C(pv_as)
            print(f" tOtO shape: {cute.shape(tOtO)} layout: {tOtO.layout}")

            # PV A-operand SMEM layout (sP)
            p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1)
            # p_smem_s.outer is the layout. SMEM cannot be allocated here
            # (that is the SmemAllocator's job in the real kernel), so a
            # tensor over a dummy pointer is used purely for shape analysis.
            sP_alloc = cute.make_tensor(0, p_smem_s.outer)
            print(f" sP shape: {cute.shape(sP_alloc)}")
            sP_2d = cute.group_modes(sP_alloc, 0, 3)
            print(f" sP_2d shape: {cute.shape(sP_2d)} rank={len(cute.shape(sP_2d))}")
            # Printed once; the original emitted this exact line twice.
            print(f" sP_2d layout: {sP_2d.layout}")

            # PV A-operand fragments can't be created with dummy ptr 0 (no
            # memspace), so just print the layout info from the SMEM layout.
            print(f" p_smem_s.outer (sP layout): {p_smem_s.outer}")
            print(f" p_smem_s.inner (swizzle): {p_smem_s.inner}")

            # What we need: the PV MMA's A-operand thread partition for sP.
            # This tells us which threads read which sP elements during the
            # PV GEMM. The softmax warps must WRITE to sP using the same
            # mapping (so that the MMA warp can READ using its own partition).

            # Softmax TMEM load partition
            sfw_idx = 0
            tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32)
            tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS)
            thr_load = tiled_tmem_load.get_slice(sfw_idx)
            tTMEM_LOADtS = thr_load.partition_S(tStS)
            cS = cute.make_identity_tensor((128, 128))
            tScS = qk_thr.partition_C(cS)
            tTMEM_LOADcS = thr_load.partition_D(tScS)
            print(f" tTMEM_LOADcS (coord) shape: {cute.shape(tTMEM_LOADcS)} layout: {tTMEM_LOADcS.layout}")

            # P sub-layout in TMEM
            p_cols_fp32 = pv_mma_tiler[2] * BFloat16.width // Float32.width
            print(f" p_cols_fp32: {p_cols_fp32}")

            # make_tiled_copy_C with qk_mma
            print(f"\n --- make_tiled_copy_C(qk_mma) ---")
            try:
                smem_copy_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), BFloat16, num_bits_per_copy=128)
                tiled_copy = cute.make_tiled_copy_C(smem_copy_atom, qk_mma)
                thr_copy = tiled_copy.get_slice(sfw_idx)
                print(f" OK, created tiled_copy")
                tD = thr_copy.partition_D(sP_2d)
                print(f" partition_D(sP_2d) shape: {cute.shape(tD)} rank={len(cute.shape(tD))}")

                # Create the source tensor using the QK C-fragment layout;
                # this is what make_tiled_copy_C expects as source.
                qk_C_reg = qk_thr.make_fragment_C(qk_as)
                # But we need a 2D version (without stage mode)
                qk_C_2d = cute.group_modes(qk_C_reg, 0, 2)
                print(f" qk_C_reg shape: {cute.shape(qk_C_reg)} layout: {qk_C_reg.layout}")
                print(f" qk_C_2d shape: {cute.shape(qk_C_2d)} layout: {qk_C_2d.layout}")
                try:
                    tS = thr_copy.partition_S(qk_C_2d)
                    print(f" partition_S(qk_C_2d) shape: {cute.shape(tS)} rank={len(cute.shape(tS))}")
                    print(f" partition_S and partition_D RANKS MATCH: {len(cute.shape(tS)) == len(cute.shape(tD))}")
                except Exception as e:
                    print(f" partition_S(qk_C_2d) FAILED: {e}")
                    # Fallback: try with a P-sized sub-tensor of the QK
                    # C-fragment. P is only the first p_cols_fp32 columns
                    # of S.
                    p_cols_bf16 = pv_mma_tiler[2]  # K dimension of PV = columns of P
                    print(f" p_cols_bf16 (PV tiler[2]): {p_cols_bf16}")
                    qk_P_shape = (128, p_cols_bf16)  # P matrix is (128, p_cols_bf16)
                    qk_P_layout = cute.composition(qk_C_reg.layout, cute.make_layout(qk_P_shape))
                    qk_P = cute.make_tensor(0, qk_P_layout)
                    qk_P_2d = cute.group_modes(qk_P, 0, 2)
                    print(f" qk_P_2d shape: {cute.shape(qk_P_2d)} layout: {qk_P_2d.layout}")
                    try:
                        tS = thr_copy.partition_S(qk_P_2d)
                        print(f" partition_S(qk_P_2d) shape: {cute.shape(tS)} rank={len(cute.shape(tS))}")
                        print(f" RANKS MATCH: {len(cute.shape(tS)) == len(cute.shape(tD))}")
                    except Exception as e2:
                        print(f" partition_S(qk_P_2d) FAILED: {e2}")
            except Exception as e:
                print(f" make_tiled_copy_C FAILED: {e}")

            # make_tiled_copy_C with pv_mma
            print(f"\n --- make_tiled_copy_C(pv_mma) ---")
            try:
                # Build a dedicated copy atom so this probe does not depend
                # on the qk_mma section above having bound its atom; a
                # failure there would otherwise show up here as a NameError.
                pv_copy_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), BFloat16, num_bits_per_copy=128)
                tiled_copy_pv = cute.make_tiled_copy_C(pv_copy_atom, pv_mma)
                thr_copy_pv = tiled_copy_pv.get_slice(sfw_idx)
                tD_pv = thr_copy_pv.partition_D(sP_2d)
                print(f" OK, partition_D(sP_2d) shape: {cute.shape(tD_pv)} rank={len(cute.shape(tD_pv))}")
            except Exception as e:
                print(f" make_tiled_copy_C(pv_mma) FAILED: {e}")

            # KEY QUESTION: what is the relationship between the QK
            # C-fragment and the TMEM load partition? Are they the same
            # threads, same values? Compare the sizes printed below: do the
            # C-fragment and the TMEM load have the same per-thread element
            # count?
            print(f"\n --- QK C-fragment vs TMEM load partition ---")
            print(f" tStS shape: {cute.shape(tStS)} layout: {tStS.layout}")
            print(f" tTMEM_LOADtS shape: {cute.shape(tTMEM_LOADtS)}")
            print(f" tTMEM_LOADcS shape: {cute.shape(tTMEM_LOADcS)} layout: {tTMEM_LOADcS.layout}")
            print(f" size(tStS) = {cute.size(tStS)}, size(tTMEM_LOADcS) = {cute.size(tTMEM_LOADcS)}")

        # Compile and run. The tensors are wrapped as CuTe tensors with a
        # dynamic layout keyed on each tensor's leading dimension.
        v3 = v.unsqueeze(-1)  # give V a trailing unit mode like q / k / c
        mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
        mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
        mV = ct.from_dlpack(v3).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v3))
        mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))

        print(' Compiling diagnostic...', flush=True)
        compiled = cute.compile(_diag_kernel, mQ, mK, mV, mC)
        compiled(mQ, mK, mV, mC)
        torch.cuda.synchronize()
# Script entry point: run the layout diagnostics only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()