- Bug 1 (V MN-major): fix applied.
- Bug 2 (softmax packing): confirmed correct (V = I test: cosine 1.0).
- Bug 3 (ACCUMULATE): fix applied (the first PV MMA must overwrite the accumulator, not accumulate into it).
- Bug 4 (CURRENT): the PV MMA is broken for non-square output.
  - (128, 128) PV with random V: cosine 0.999999 ✅
  - (128, 64) PV with MN-major V: cosine ~0.01 ❌
  - Softmax packing, layout aliasing, and pipeline ordering are all verified correct.
  - Root cause unknown — likely an epilogue, V-layout, or MMA-tiler issue.

Added test_pv_diag.py (V = I and random V, 128x128 output — PASS).
Added test_layout_compare.py (TMEM layout inspection).
Added test_inspect_types.py (TMEM pointer arithmetic verification).
Updated test_mma_si_pv.py with a head_dim param, the pv_mma_tiler_mn fix, and the ACCUMULATE fix.
Updated READMEs with the current state.
96 lines
4.3 KiB
Python
96 lines
4.3 KiB
Python
"""Compare C-fragment composition layout vs A-fragment layout for PV P operand."""
|
|
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils
|
|
from cutlass.cute.nvgpu import tcgen05
|
|
from cutlass import Float32, BFloat16
|
|
from cutlass.utils import LayoutEnum
|
|
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
|
|
import cuda.bindings.driver as cuda
|
|
import cutlass.torch as ct
|
|
|
|
|
|
class LayoutCompareKernel:
    """Print and compare the two TMEM views of the P operand used by the PV MMA.

    Host-side (``@cute.jit``) harness. It does the following:

    * Builds a QK tiled MMA (A and B from SMEM).
    * Builds a PV tiled MMA whose A operand (P) is sourced from TMEM.
    * Constructs two views of P, both aliasing the storage of the S (QK)
      accumulator fragment:

      - the PV A-operand fragment;
      - a composition of the S C-fragment layout.

    * Prints the layouts of both views with ``cute.printf`` so they can be
      compared.

    No kernel is launched.
    """

    def __init__(self):
        self.acc_dtype = Float32
        self.qk_acc_dtype = Float32
        self.q_dtype = BFloat16
        self.o_dtype = BFloat16
        self.c_dtype = BFloat16
        # (M, N) of the QK MMA tile. The PV tile is derived from the full QK
        # tile inside __call__.
        self.mma_tiler_mn = (128, 128)
        self.cta_group = tcgen05.CtaGroup.ONE
        self.threads_per_cta = 64  # minimal; not read by __call__
        # Offset, in Float32 accumulator elements, at which P starts inside
        # the S accumulator. Shared by both P views below so they cannot drift.
        self.tmem_p0_offset = 32

    @cute.jit
    def __call__(self, q, k, v, c, stream):
        """Build QK/PV tiled MMAs from the layouts of q/k/v and print P layouts.

        Only the major modes of ``q``, ``k`` and ``v`` are inspected.
        ``c`` and ``stream`` are accepted for call-signature parity and are
        not used here.
        """
        a_major = LayoutEnum.from_tensor(q).mma_major_mode()
        b_major = LayoutEnum.from_tensor(k).mma_major_mode()
        v_major = LayoutEnum.from_tensor(v).mma_major_mode()

        qk_mma = utils.sm100.make_trivial_tiled_mma(
            self.q_dtype, self.q_dtype, a_major, b_major,
            self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn,
            tcgen05.OperandSource.SMEM)

        # QK tile is (M, N, K) with K = 4 instruction-K steps. For BFloat16 this
        # is presumably 4 * 16 = 64, matching head_dim in test() — confirm.
        qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
        qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
        # O = P @ V: the PV tile is (M, N=K_qk, K=N_qk).
        pv_mma_tiler = (qk_mma_tiler[0], qk_mma_tiler[2], qk_mma_tiler[1])

        # The PV tiled MMA must be built for the PV tile's own (M, N). Reusing
        # the QK (M, N) makes its N (128) disagree with the PV output width
        # (K_qk), so the C partition below would not match.
        pv_mma = utils.sm100.make_trivial_tiled_mma(
            self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, v_major,
            self.qk_acc_dtype, self.cta_group, pv_mma_tiler[:2],
            tcgen05.OperandSource.TMEM)

        qk_thr = qk_mma.get_slice(0)
        pv_thr = pv_mma.get_slice(0)

        # S (QK output) and O (PV output) accumulator fragments.
        qk_acc_shape = qk_thr.partition_shape_C(qk_mma_tiler[:2])
        tStS = qk_thr.make_fragment_C(qk_acc_shape)

        pv_acc_shape = pv_thr.partition_shape_C(pv_mma_tiler[:2])
        tOtO = pv_thr.make_fragment_C(pv_acc_shape)

        # View 1: P as the PV A-operand fragment, built on tStS's storage.
        p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, self.q_dtype, 1)
        tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
        tOrP_base = pv_thr.make_fragment_A(tP)
        tOrP = tOrP_base[(None, None, None, 0)]  # the single stage (stage count 1 above)

        # View 2: P through a composition of the S C-fragment layout.
        # N values of o_dtype per row occupy N * o_width / 32 Float32 columns.
        tilePlikeFP32 = qk_mma_tiler[1] // Float32.width * self.o_dtype.width
        tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, tilePlikeFP32)))
        tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)

        # Same P offset, rescaled from Float32 elements to q_dtype (A) elements.
        p_offset_in_a_elements = (
            self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset)  # = 64
        tOrP0 = cute.make_tensor(tOrP.iterator + p_offset_in_a_elements, tOrP.layout)

        # Print layouts
        cute.printf("tStS layout: {}", tStS.layout)
        cute.printf("tOrP layout: {}", tOrP.layout)
        cute.printf("tStS_P layout: {}", tStS_P_layout)
        cute.printf("tOrP0 layout: {}", tOrP0.layout)
        cute.printf("tOrP shape: {}", tOrP.shape)
        cute.printf("tStS_P shape: {}", tStS_P.shape)
        cute.printf("tOtO layout: {}", tOtO.layout)
        cute.printf("pv_mma_tiler: {}", pv_mma_tiler)
        cute.printf("qk_mma_tiler: {}", qk_mma_tiler)
|
|
|
|
def test():
    """Allocate small Q/K/V/C GPU tensors and run LayoutCompareKernel on them.

    V is created with stride 1 along its first (head_dim) dimension by
    reinterpreting a freshly allocated buffer with ``as_strided``.
    """
    m, n, head_dim = 128, 128, 64
    bf16 = torch.bfloat16
    dev = 'cuda'

    q = torch.randn(m, head_dim, 1, dtype=bf16, device=dev)
    k = torch.randn(n, head_dim, 1, dtype=bf16, device=dev)
    v = (
        torch.randn(head_dim, n, dtype=bf16, device=dev)
        .as_strided((head_dim, n), (1, head_dim))
        .unsqueeze(-1)
    )
    c = torch.zeros(m, head_dim, 1, dtype=bf16, device=dev)

    def to_dynamic_cute(t):
        # Wrap a torch tensor for CuTe, with a dynamic layout keyed on its
        # leading dimension.
        return ct.from_dlpack(t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(t))

    mQ, mK, mV, mC = (to_dynamic_cute(t) for t in (q, k, v, c))
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    kernel = LayoutCompareKernel()
    args = (mQ, mK, mV, mC, stream)

    print('Compiling...', flush=True)
    compiled = cute.compile(kernel, *args)

    print('Running...', flush=True)
    compiled(*args)
    torch.cuda.synchronize()
|
|
|
|
if __name__ == '__main__':
    # Script entry point: run the layout comparison.
    test()