Files
nvfp4-megamoe-kernel/tests/test_layout_compare.py
biondizzle 0dc6fe4a7d Stage B progress: PV works for square (128,128), broken for (128,64)
- Bug 1 (V MN-major): Fix applied
- Bug 2 (softmax packing): Confirmed correct (V=I test: cosine 1.0)
- Bug 3 (ACCUMULATE): Fix applied (first PV must overwrite, not accumulate)
- Bug 4 (CURRENT): PV MMA broken for non-square output
  - (128,128) PV with random V: cosine 0.999999 (PASS)
  - (128,64) PV with MN-major V: cosine ~0.01 (FAIL)
  - Softmax packing, layout aliasing, pipeline ordering all verified correct
  - Root cause unknown — likely epilogue/V layout/MMA tiler issue

Added test_pv_diag.py (V=I and random V, 128x128 output — PASS)
Added test_layout_compare.py (TMEM layout inspection)
Added test_inspect_types.py (TMEM pointer arithmetic verification)
Updated test_mma_si_pv.py with head_dim param, pv_mma_tiler_mn fix, ACCUMULATE fix
Updated READMEs with current state
2026-05-21 04:40:28 +00:00

96 lines
4.3 KiB
Python

"""Compare C-fragment composition layout vs A-fragment layout for PV P operand."""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils
from cutlass.cute.nvgpu import tcgen05
from cutlass import Float32, BFloat16
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
import cutlass.torch as ct
class LayoutCompareKernel:
    """Layout-inspection helper for the QK accumulator and the PV "P" operand.

    Calling an instance builds two tiled MMAs (QK and PV), derives the
    accumulator / operand fragments and two alternative views of the P
    operand, and prints their layouts and shapes with ``cute.printf``.
    The body of ``__call__`` only constructs and prints; it does not write
    to any of its tensor arguments.
    """

    def __init__(self):
        """Set the fixed dtypes and tile configuration used by ``__call__``."""
        # Accumulator dtypes.
        self.acc_dtype = Float32
        self.qk_acc_dtype = Float32
        # Operand / output dtypes.
        self.q_dtype = BFloat16
        self.o_dtype = BFloat16
        self.c_dtype = BFloat16
        # (M, N) tile handed to both tiled-MMA constructors.
        self.mma_tiler_mn = (128, 128)
        self.cta_group = tcgen05.CtaGroup.ONE
        self.threads_per_cta = 64  # minimal

    @cute.jit
    def __call__(self, q, k, v, c, stream):
        """Build the QK/PV fragments and print their layouts.

        Args:
            q, k, v: tensors whose layouts determine the operand major modes
                (queried through ``LayoutEnum.from_tensor``).
            c: accepted for signature compatibility; not referenced here.
            stream: accepted for signature compatibility; not referenced here.
        """
        q_major = LayoutEnum.from_tensor(q).mma_major_mode()
        k_major = LayoutEnum.from_tensor(k).mma_major_mode()
        v_major = LayoutEnum.from_tensor(v).mma_major_mode()

        # QK MMA: both operands sourced from SMEM.
        qk_mma = utils.sm100.make_trivial_tiled_mma(
            self.q_dtype,
            self.q_dtype,
            q_major,
            k_major,
            self.qk_acc_dtype,
            self.cta_group,
            self.mma_tiler_mn,
            tcgen05.OperandSource.SMEM,
        )
        # PV MMA: A operand is K-major and sourced from TMEM; B follows V's
        # major mode.
        pv_mma = utils.sm100.make_trivial_tiled_mma(
            self.q_dtype,
            self.q_dtype,
            cute.nvgpu.OperandMajorMode.K,
            v_major,
            self.qk_acc_dtype,
            self.cta_group,
            self.mma_tiler_mn,
            tcgen05.OperandSource.TMEM,
        )

        # The K extent of the QK tile is four MMA instructions deep; the PV
        # tile swaps the N and K extents of the QK tile.
        inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
        tile_m, tile_n = self.mma_tiler_mn
        tile_k = inst_k * 4
        qk_mma_tiler = (tile_m, tile_n, tile_k)
        pv_mma_tiler = (tile_m, tile_k, tile_n)

        qk_slice = qk_mma.get_slice(0)
        pv_slice = pv_mma.get_slice(0)

        # QK accumulator (C) fragment.
        s_shape = qk_slice.partition_shape_C(qk_mma_tiler[:2])
        s_frag = qk_slice.make_fragment_C(s_shape)
        # Tensor over the same iterator/layout; nothing below references it.
        s_frag_alias = cute.make_tensor(s_frag.iterator, s_frag.layout)

        # PV accumulator (C) fragment.
        o_shape = pv_slice.partition_shape_C(pv_mma_tiler[:2])
        o_frag = pv_slice.make_fragment_C(o_shape)

        # P A-fragment: PV's A operand placed on the QK accumulator's iterator.
        p_layout_a = utils.sm100.make_smem_layout_a(
            pv_mma, pv_mma_tiler, self.q_dtype, 1
        )
        p_tensor = cute.make_tensor(s_frag.iterator, p_layout_a.outer)
        p_frag_staged = pv_slice.make_fragment_A(p_tensor)
        p_frag = p_frag_staged[(None, None, None, 0)]

        # C-fragment composition layout: the QK accumulator layout composed
        # with a (128, cols) tile, where cols rescales tile_n by the
        # o_dtype / Float32 width ratio.
        p_cols_as_fp32 = qk_mma_tiler[1] // Float32.width * self.o_dtype.width
        p_tile_layout = cute.make_layout((128, p_cols_as_fp32))
        p_view_layout = cute.composition(s_frag.layout, p_tile_layout)
        # offset 32 FP32 columns
        p_view = cute.make_tensor(s_frag.iterator + 32, p_view_layout)

        # Same 32-column offset applied to the A-fragment, scaled by the
        # accumulator/operand width ratio (= 64 for the configured dtypes).
        elems_per_acc = self.qk_acc_dtype.width // self.q_dtype.width
        p_offset_elems = elems_per_acc * 32
        p_frag_shifted = cute.make_tensor(
            p_frag.iterator + p_offset_elems, p_frag.layout
        )

        # Print layouts
        cute.printf("tStS layout: {}", s_frag.layout)
        cute.printf("tOrP layout: {}", p_frag.layout)
        cute.printf("tStS_P layout: {}", p_view_layout)
        cute.printf("tOrP0 layout: {}", p_frag_shifted.layout)
        cute.printf("tOrP shape: {}", p_frag.shape)
        cute.printf("tStS_P shape: {}", p_view.shape)
        cute.printf("tOtO layout: {}", o_frag.layout)
        cute.printf("pv_mma_tiler: {}", pv_mma_tiler)
        cute.printf("qk_mma_tiler: {}", qk_mma_tiler)
def test():
    """Compile and run ``LayoutCompareKernel`` once on random bf16 CUDA inputs.

    Uses m = n = 128 and head_dim = 64. ``v`` is a strided view of a
    (head_dim, n) buffer with stride 1 along head_dim. The kernel's printed
    layouts are the only output; nothing is asserted.
    """
    m, n, head_dim = 128, 128, 64
    tensor_opts = dict(dtype=torch.bfloat16, device='cuda')

    q = torch.randn(m, head_dim, 1, **tensor_opts)
    k = torch.randn(n, head_dim, 1, **tensor_opts)
    v_storage = torch.randn(head_dim, n, **tensor_opts)
    # Re-stride so consecutive head_dim entries are adjacent in memory, then
    # append the trailing size-1 mode that q/k/c also carry.
    v = v_storage.as_strided((head_dim, n), (1, head_dim)).unsqueeze(-1)
    c = torch.zeros(m, head_dim, 1, **tensor_opts)

    def to_dynamic_cute(tensor):
        """Wrap a torch tensor for CuTe with a dynamic layout."""
        leading = ct.get_leading_dim(tensor)
        return ct.from_dlpack(tensor).mark_layout_dynamic(leading_dim=leading)

    mQ, mK, mV, mC = (to_dynamic_cute(t) for t in (q, k, v, c))

    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
    kernel = LayoutCompareKernel()

    print('Compiling...', flush=True)
    compiled = cute.compile(kernel, mQ, mK, mV, mC, stream)

    print('Running...', flush=True)
    compiled(mQ, mK, mV, mC, stream)
    torch.cuda.synchronize()
# Script entry point: run the layout dump only when executed directly.
if __name__ == "__main__":
    test()