"""Minimal PV-only test: load P into TMEM via a QK-style MMA, then run PV from TMEM.

Pipeline under test:
  Step 1: QK MMA writes FP32 S to TMEM (we know this works).
  Step 2: Softmax packing writes BF16 P to TMEM (this is what we want to test).
  Step 3: PV MMA reads BF16 P from TMEM and V from SMEM, and produces O.

To isolate the bug, the PV MMA should be tested on its own.

First idea: write known BF16 values to TMEM using the softmax packing path,
then immediately read them back using the PV A-fragment path, and compare.

Simplest isolation test:
  1. Do the QK MMA to get S in TMEM (cosine 0.999999 verified).
  2. Do softmax packing: S → P in TMEM (at offset 32).
  3. Skip PV entirely and read P from TMEM using the C-fragment composition
     LOAD path.
  4. Output P to GMEM and compare against S.to(BF16).
This tests whether the softmax packing writes P correctly to the same TMEM
that the PV MMA would read from. However, P cannot easily be read from TMEM
using the standard epilogue path, because the epilogue expects FP32
accumulator data.

Alternative: use the PV MMA with V = I (identity). If P is correct, then
P @ I = P. But V needs to be MN-major and (128, 128), not (128, 64), so the
output would be (128, 128), which doesn't match our (128, 64) c tensor.

Another option is a V that selects the first 64 columns:
V[k, n] = delta(k, n) for k in [0, 63]. This gives P @ V = P[:, :64], and the
output is (128, 64). But V is (128, 128) in the MMA K, N dims, i.e. V[k, n]
for k in [0, 127], n in [0, 63], and this gets complicated.

Decision: use the identity approach with a (128, 128) output.
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
import cutlass.torch as ct


class Test128x16Tiler:
    """QK + softmax packing + PV with V=I to isolate PV MMA correctness.

    Output should be P = S.to(BF16), i.e. (Q@K^T).bfloat16()
    With V=I, O = P @ I = P.

    But V is (K=128, N=128) in the MMA. We need a 128x128 identity in MN-major.
    Output tensor is (128, 128).
    """

    def __init__(self, mma_tiler_mn):
        """Record default dtypes, the MMA tile shape, and the warp-role layout.

        Args:
            mma_tiler_mn: (M, N) MMA tile shape. Stored as-is and also expanded
                to ``self.mma_tiler`` with a placeholder K of 1; ``_setup``
                later overwrites ``self.mma_tiler`` with the real QK tiler.
        """
        self.acc_dtype = Float32; self.qk_acc_dtype = Float32
        # Defaults only: __call__ overwrites q/o/c dtypes from the actual tensors.
        self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16
        self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
        self.use_2cta_instrs = False  # needed by epilogue_tma_store
        self.epilog_sync_bar_id = 1  # needed by epilogue_tma_store
        self.cluster_shape_mn = (1, 1)
        self.cta_group = tcgen05.CtaGroup.ONE
        # Warp roles: warps 0-3 are epilogue warps, warp 4 is the MMA warp and
        # warp 5 is the TMA load warp -> 6 warps * 32 threads = 192 threads.
        self.epilogue_warp_id = (0, 1, 2, 3)
        self.mma_warp_id = 4; self.tma_warp_id = 5
        self.threads_per_cta = 192
        self.num_c_stage = 2

    def _setup(self, qk_mma, pv_mma):
        """Derive tilers, SMEM layouts, TMEM offsets and the TMA byte count.

        Called from ``__call__`` once both tiled MMAs exist. All results are
        stored on ``self``; nothing is returned.

        Args:
            qk_mma: tiled MMA used for the QK product.
            pv_mma: tiled MMA used for the PV product.
        """
        # K extent of one QK MMA instruction; the QK tile covers 4 of them in K.
        qk_inst_k = int(cute.size(qk_mma.shape_mnk, mode=[2]))
        self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
        # PV with V=I: output is (128, 128), same as QK
        self.pv_mma_tiler = (self.qk_mma_tiler[0], qk_inst_k, self.qk_mma_tiler[1])
        # pv_mma_tiler = (128, 128, 128) since V is 128x128
        # NOTE(review): the code above sets the middle entry to qk_inst_k, so the
        # tuple is (qk M, qk_inst_k, qk N); confirm the (128, 128, 128) claim.
        self.mma_tiler = self.qk_mma_tiler
        self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
        self.cta_tile_shape_mnk = (
            self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
            self.qk_mma_tiler[1], self.qk_mma_tiler[2])
        self.c_layout = LayoutEnum.ROW_MAJOR
        self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
            (self.pv_mma_tiler[0], self.pv_mma_tiler[1], self.pv_mma_tiler[2]),
            False, self.c_layout, self.o_dtype)
        self.num_ab_stage = 1; self.num_acc_stage = 1
        # Single-stage layouts for Q (A operand) and K (B operand) of the QK MMA,
        # V (B operand) and P (A operand) of the PV MMA, and a 2-stage epilogue.
        self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
        self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1)
        self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
        # Built with the SMEM layout helper even though it is named "tmem";
        # _kernel only uses its `.outer` part to lay out P.
        self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
        self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
        # Column footprint of the S and O accumulator fragments.
        qk_thr = qk_mma.get_slice(0)
        qk_acc_shape = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
        tStS = qk_thr.make_fragment_C(qk_acc_shape)
        s_cols = find_tmem_tensor_col_offset(tStS)
        pv_thr = pv_mma.get_slice(0)
        pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
        tOtO = pv_thr.make_fragment_C(pv_acc_shape)
        # o_cols is computed but not used anywhere in the visible code.
        o_cols = find_tmem_tensor_col_offset(tOtO)
        # Only printed below in the visible code.
        self.tilePlikeFP32 = self.qk_mma_tiler[1] // Float32.width * self.o_dtype.width
        print(f"tilePlikeFP32={self.tilePlikeFP32}, pv_mma_tiler={self.pv_mma_tiler}, qk_mma_tiler={self.qk_mma_tiler}")
        # Offsets applied to fragment iterators in _kernel: S at 0, P at 32
        # (scaled by the acc/operand width ratio there), O right after S.
        self.tmem_s0_offset = 0
        self.tmem_p0_offset = 32
        self.tmem_o0_offset = s_cols
        # Staged accumulator fragments, built only to size the TMEM allocation.
        tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
        tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
        self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
        # ⛔⛔⛔ CRITICAL: num_tma_load_bytes MUST include ALL TMA-loaded tensors (Q + K + V). Missing V → DEADLOCK. See FOOTGUN #0 in README.
        a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
        b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
        v_smem = cute.slice_(self.v_smem_s, (None, None, None, 0))
        self.num_tma_load_bytes = (
            cute.size_in_bytes(self.q_dtype, a_smem)
            + cute.size_in_bytes(self.q_dtype, b_smem)
            + cute.size_in_bytes(self.q_dtype, v_smem)
        ) * cute.size(qk_mma.thr_id.shape)

    @cute.jit
    def __call__(self, q, k, v, c, stream):
        # Host-side entry point: reads dtypes and major modes from the tensors,
        # builds the QK and PV tiled MMAs, runs _setup, creates the TMA atoms
        # for Q/K/V loads and the C store, and launches _kernel on `stream`.
        self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype
        self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
        self.b_major = LayoutEnum.from_tensor(k).mma_major_mode()
        self.v_major = LayoutEnum.from_tensor(v).mma_major_mode()
        self.c_layout = LayoutEnum.from_tensor(c)
        # QK MMA: both operands are sourced from SMEM.
        qk_mma = utils.sm100.make_trivial_tiled_mma(
            self.q_dtype, self.q_dtype, self.a_major, self.b_major,
            self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
        # PV with 128x128 output (V=I)
        # The A operand (P) is K-major and sourced from TMEM; the MMA tile passed
        # here is the hard-coded (128, 16), independent of mma_tiler_mn.
        pv_mma = utils.sm100.make_trivial_tiled_mma(
            self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, self.v_major,
            self.qk_acc_dtype, self.cta_group, (128, 16), tcgen05.OperandSource.TMEM)
        self._setup(qk_mma, pv_mma)
        # Drop the stage mode to get the per-stage SMEM layouts for the TMA atoms.
        q_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
        k_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
        v_smem = cute.slice_(self.v_smem_s, (None, None, None, 0))
        tma_q, tma_tq = cute.nvgpu.make_tiled_tma_atom_A(
            utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
            q, q_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
        tma_k, tma_tk = cute.nvgpu.make_tiled_tma_atom_B(
            utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
            k, k_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
        tma_v, tma_tv = cute.nvgpu.make_tiled_tma_atom_B(
            utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, pv_mma.thr_id),
            v, v_smem, self.pv_mma_tiler, pv_mma, self.cluster_layout_vmnk.shape)
        epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
        tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
        # One CTA of threads_per_cta threads; the grid is fixed at (1, 1, 1).
        self._kernel(qk_mma, pv_mma, tma_q, tma_tq, tma_k, tma_tk, tma_v, tma_tv, tma_c, tma_tc,
            self.cluster_layout_vmnk, self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
        ).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)

    @cute.kernel
    def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV, tma_c, mC, cl_vmnk,
                a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
        # Device kernel. The visible part sets up barriers/pipelines, SMEM
        # tensors, the TMA and MMA partitions and the TMEM fragment views, then
        # starts the TMA load warp. The rest of the body is missing from SOURCE
        # (see the NOTE at the end of this method).
        warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
        tidx, _, _ = cute.arch.thread_idx()
        use_2cta = cute.size(qk_mma.thr_id.shape) == 2
        # Only the TMA warp prefetches the four TMA descriptors.
        if warp_idx == self.tma_warp_id:
            cpasync.prefetch_descriptor(tma_q); cpasync.prefetch_descriptor(tma_k)
            cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c)

        # Storage for the pipeline barriers plus the TMEM allocator bookkeeping.
        @cute.struct
        class SS:
            ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
            mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
            acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
            tmem_dealloc: cutlass.Int64
            holding: cutlass.Int32

        smem = utils.SmemAllocator(); st = smem.allocate(SS)
        # TMA -> MMA pipeline for the Q/K/V loads; tx_count is the byte total
        # computed in _setup.
        ab_p, ab_c = pipeline.PipelineTmaUmma.create(
            barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
            producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
            consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
            tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
        ).make_participants()
        # Single-stage pipeline whose consumer group spans every thread of the
        # epilogue warps (32 * number of epilogue warps).
        mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
            barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
            producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
            consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
        ).make_participants()
        # Accumulator pipeline; its consumer count is one per epilogue warp
        # (doubled when 2-CTA instructions are in use).
        acc_pipe = pipeline.PipelineUmmaAsync.create(
            barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
            producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
            consumer_group=pipeline.CooperativeGroup(
                pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)),
            cta_layout_vmnk=cl_vmnk, defer_sync=True)
        # Named barrier sized for the MMA warp plus the epilogue warps; handed
        # to the TMEM allocator, whose allocator warp is the first epilogue warp.
        tmem_bar = pipeline.NamedBarrier(barrier_id=2, num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
        tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
            allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta,
            two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
        pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
        # SMEM tensors for Q, K, V and the epilogue staging buffer C.
        sQ = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
        sK = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
        sV = smem.allocate_tensor(element_type=self.q_dtype, layout=v_smem_s.outer, byte_alignment=128, swizzle=v_smem_s.inner)
        sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
        # Tile the global tensors with the QK tiler (note: C is tiled with the
        # QK tiler too, not the PV tiler).
        gQ = cute.local_tile(mQ, cute.slice_(self.qk_mma_tiler, (None,0,None)), (None,None,None))
        gK = cute.local_tile(mK, cute.slice_(self.qk_mma_tiler, (0,None,None)), (None,None,None))
        gC = cute.local_tile(mC, cute.slice_(self.qk_mma_tiler, (None,None,0)), (None,None,None))
        # Trip count of the TMA load loop below.
        k_cnt = cute.size(gQ, mode=[3])
        qk_thr = qk_mma.get_slice(0)
        pv_thr = pv_mma.get_slice(0)
        tCgQ = qk_thr.partition_A(gQ); tCgK = qk_thr.partition_B(gK); tCgC = qk_thr.partition_C(gC)
        # TMA partitions (SMEM side, GMEM side) for Q and K.
        a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
        tAsQ, tAgQ = cpasync.tma_partition(tma_q, 0, a_lay, cute.group_modes(sQ,0,3), cute.group_modes(tCgQ,0,3))
        b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
        tBsK, tBgK = cpasync.tma_partition(tma_k, 0, b_lay, cute.group_modes(sK,0,3), cute.group_modes(tCgK,0,3))
        tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)]
        # Same for V, partitioned as the B operand of the PV MMA; it reuses b_lay.
        gV = cute.local_tile(mV, cute.slice_(self.pv_mma_tiler, (0,None,None)), (None,None,None))
        tCgV = pv_thr.partition_B(gV)
        tVsV, tVgV = cpasync.tma_partition(tma_v, 0, b_lay, cute.group_modes(sV,0,3), cute.group_modes(tCgV,0,3))
        tVgV = tVgV[(None,0,None,0)]
        # MMA operand fragments over the SMEM tensors.
        tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
        tCrV = pv_mma.make_fragment_B(sV)
        # Accumulator fragments for S and O, shifted by the offsets from _setup.
        qk_acc_shape = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
        tStS = qk_thr.make_fragment_C(qk_acc_shape)
        tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
        pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
        tOtO = pv_thr.make_fragment_C(pv_acc_shape)
        tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
        # P is viewed through S's iterator with the P layout, then turned into
        # the PV A-operand fragment (stage 0).
        tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
        tOrP_base = pv_thr.make_fragment_A(tP)
        tOrP = tOrP_base[(None, None, None, 0)]
        # tmem_p0_offset is scaled by the accumulator/operand width ratio before
        # being added to the P fragment iterator.
        tOrP0 = cute.make_tensor(
            tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
            tOrP.layout)
        tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
        tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
        pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
        # ═══ TMA LOAD WARP ═══
        if warp_idx == self.tma_warp_id:
            ab_p.reset(); peek = ab_p.try_acquire()
            for kt in cutlass.range(k_cnt, unroll=1):
                h = ab_p.acquire_and_advance(peek)
                # Q, K and V copies all signal the same stage barrier.
                cute.copy(tma_q, tAgQ[(None,h.count)], tAsQ[(None,h.index)], tma_bar_ptr=h.barrier)
                cute.copy(tma_k, tBgK[(None,h.count)], tBsK[(None,h.index)], tma_bar_ptr=h.barrier)
                cute.copy(tma_v, tVgV[(None,h.count)], tVsV[(None,h.index)], tma_bar_ptr=h.barrier)
                # NOTE(review): SOURCE is corrupted from here on. The text between
                # `h.count+1` and `= 0.99 else 'FAIL'))` is missing from this view
                # (presumably the rest of this kernel and the `test()` driver that
                # is called below), so the next line is kept verbatim and is not
                # valid Python. Restore it from version control before running.
                peek = cutlass.Boolean(1) if h.count+1= 0.99 else 'FAIL'))

if __name__ == '__main__': test()