"""Compare C-fragment composition layout vs A-fragment layout for PV P operand."""
import torch
import cutlass
import cutlass.cute as cute
import cutlass.utils as utils
from cutlass.cute.nvgpu import tcgen05
from cutlass import Float32, BFloat16
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
import cutlass.torch as ct


class LayoutCompareKernel:
    """Host-side JIT probe that prints two views of the PV GEMM's P operand.

    Builds an S = Q K^T tiled MMA and an O = P V tiled MMA (P sourced via
    ``tcgen05.OperandSource.TMEM``), then prints:

    * an A-fragment view: the PV MMA's A-operand fragment ``tOrP`` laid over
      the storage of the S accumulator ``tStS``, and
    * a C-fragment view: ``tStS``'s layout composed with a (rows, slots)
      layout so P values are addressed through accumulator-sized slots.

    Both views are shifted by the same amount of S storage so the printed
    layouts can be compared side by side. No device kernel is launched.
    """

    def __init__(self):
        # Accumulator element types: S = Q K^T uses qk_acc_dtype, O = P V uses acc_dtype.
        self.acc_dtype = Float32
        self.qk_acc_dtype = Float32
        # Operand / output element types; c_dtype is not read by __call__.
        self.q_dtype = BFloat16
        self.o_dtype = BFloat16
        self.c_dtype = BFloat16
        self.mma_tiler_mn = (128, 128)
        self.cta_group = tcgen05.CtaGroup.ONE
        self.threads_per_cta = 64  # minimal; not read by __call__

    @cute.jit
    def __call__(self, q, k, v, c, stream):
        """Build the QK and PV MMAs and print the P-operand layouts.

        :param q: Q tensor; its major mode selects the QK A-operand layout.
        :param k: K tensor; its major mode selects the QK B-operand layout.
        :param v: V tensor; its major mode selects the PV B-operand layout.
        :param c: Unused; accepted to keep the call signature.
        :param stream: Unused; nothing is launched on it.
        """
        a_major = LayoutEnum.from_tensor(q).mma_major_mode()
        b_major = LayoutEnum.from_tensor(k).mma_major_mode()
        v_major = LayoutEnum.from_tensor(v).mma_major_mode()
        # P is the PV A operand; it is built with the same element type as Q,
        # and every P-size computation below uses this one dtype.
        p_dtype = self.q_dtype

        # S = Q K^T with the A operand (Q) read from shared memory.
        qk_mma = utils.sm100.make_trivial_tiled_mma(
            self.q_dtype, self.q_dtype, a_major, b_major, self.qk_acc_dtype,
            self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
        # O = P V with a K-major A operand (P) read from tensor memory.
        pv_mma = utils.sm100.make_trivial_tiled_mma(
            p_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, v_major,
            self.acc_dtype, self.cta_group, self.mma_tiler_mn,
            tcgen05.OperandSource.TMEM)

        # QK K-tile spans 4 MMA instructions (presumably chosen to equal
        # head_dim=64 used in test() -- TODO confirm for other head dims).
        qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
        qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
        # PV tiler is the QK tiler with its N and K extents swapped.
        pv_mma_tiler = (qk_mma_tiler[0], qk_mma_tiler[2], qk_mma_tiler[1])

        qk_thr = qk_mma.get_slice(0)
        pv_thr = pv_mma.get_slice(0)

        qk_acc_shape = qk_thr.partition_shape_C(qk_mma_tiler[:2])
        tStS = qk_thr.make_fragment_C(qk_acc_shape)
        pv_acc_shape = pv_thr.partition_shape_C(pv_mma_tiler[:2])
        tOtO = pv_thr.make_fragment_C(pv_acc_shape)  # only printed below

        # --- A-fragment view: PV A-operand layout over S's storage ---
        p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, p_dtype, 1)
        tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
        tOrP_base = pv_thr.make_fragment_A(tP)
        # Take index 0 of the trailing mode (presumably the single stage
        # requested by the final `1` above).
        tOrP = tOrP_base[(None, None, None, 0)]

        # --- C-fragment view: P packed into accumulator-sized slots of S ---
        # Slots of qk_acc_dtype needed for one row of P; multiply before
        # dividing so rows that are not a multiple of the slot ratio are not
        # under-counted.
        tilePlikeFP32 = qk_mma_tiler[1] * p_dtype.width // self.qk_acc_dtype.width
        tStS_P_layout = cute.composition(
            tStS.layout, cute.make_layout((qk_mma_tiler[0], tilePlikeFP32)))

        # Both views start p_col_offset qk_acc_dtype (FP32) elements into S.
        p_col_offset = 32
        tStS_P = cute.make_tensor(tStS.iterator + p_col_offset, tStS_P_layout)
        # The A fragment is indexed in p_dtype elements, so the same offset is
        # scaled by the number of p_dtype values per accumulator slot (= 64).
        p_offset_in_a_elements = (
            self.qk_acc_dtype.width // p_dtype.width * p_col_offset)
        tOrP0 = cute.make_tensor(tOrP.iterator + p_offset_in_a_elements, tOrP.layout)

        # Print layouts
        cute.printf("tStS layout: {}", tStS.layout)
        cute.printf("tOrP layout: {}", tOrP.layout)
        cute.printf("tStS_P layout: {}", tStS_P_layout)
        cute.printf("tOrP0 layout: {}", tOrP0.layout)
        cute.printf("tOrP shape: {}", tOrP.shape)
        cute.printf("tStS_P shape: {}", tStS_P.shape)
        cute.printf("tOtO layout: {}", tOtO.layout)
        cute.printf("pv_mma_tiler: {}", pv_mma_tiler)
        cute.printf("qk_mma_tiler: {}", qk_mma_tiler)


def test():
    """Compile and run LayoutCompareKernel on small BF16 CUDA tensors.

    Only the tensors' shapes and strides influence what is printed; their
    values are never read by the kernel.
    """
    m, n, head_dim = 128, 128, 64
    q = torch.randn(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')
    k = torch.randn(n, head_dim, 1, dtype=torch.bfloat16, device='cuda')
    # V has shape (head_dim, n) with strides (1, head_dim): head_dim is the
    # contiguous dimension (a plain transpose of a contiguous (n, head_dim)).
    v = torch.randn(n, head_dim, dtype=torch.bfloat16, device='cuda').t().unsqueeze(-1)
    c = torch.zeros(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')

    def to_cute(t):
        # Wrap a torch tensor for CuTe, marking its layout dynamic along the
        # tensor's leading dimension.
        return ct.from_dlpack(t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(t))

    mQ, mK, mV, mC = map(to_cute, (q, k, v, c))

    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
    kernel = LayoutCompareKernel()
    print('Compiling...', flush=True)
    compiled = cute.compile(kernel, mQ, mK, mV, mC, stream)
    print('Running...', flush=True)
    compiled(mQ, mK, mV, mC, stream)
    torch.cuda.synchronize()


if __name__ == '__main__':
    test()