From 625837fd447a27a4bcfe569ca3d1008a2b2d0fd9 Mon Sep 17 00:00:00 2001
From: biondizzle
Date: Sun, 24 May 2026 14:19:26 +0000
Subject: [PATCH] D1.4: Add hd=512 QK-only and standalone test for compilation debugging

---
 tests/unit/test_d1_hd512_only.py |  96 +++++++++++++++++
 tests/unit/test_d1_qk512.py      | 178 +++++++++++++++++++++++++++++++
 2 files changed, 274 insertions(+)
 create mode 100644 tests/unit/test_d1_hd512_only.py
 create mode 100644 tests/unit/test_d1_qk512.py

diff --git a/tests/unit/test_d1_hd512_only.py b/tests/unit/test_d1_hd512_only.py
new file mode 100644
index 00000000..62599cdb
--- /dev/null
+++ b/tests/unit/test_d1_hd512_only.py
@@ -0,0 +1,96 @@
+"""D1 test: HEAD_DIM=512 only (faster iteration on compilation issues)."""
+import torch, math, sys
+import cutlass.cute as cute
+import cutlass.torch as ct
+import cuda.bindings.driver as cuda
+from dsv4.kernels.attention.fmha import FmhaKernel
+
+
+def test():
+    """Compile FmhaKernel once for hd=512, run every PV tile, check the output.
+
+    The reference is exp(scores - row_max) @ V computed in float32 with torch.
+    The assembled kernel output is compared to it by cosine similarity, and
+    the test fails when that similarity is below 0.99.
+    """
+    torch.manual_seed(42)
+    hd, n = 512, 128
+    m = 128
+    q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
+    k = torch.randn(n, hd, 1, dtype=torch.bfloat16, device='cuda')
+    v = torch.randn(n, hd, dtype=torch.bfloat16, device='cuda')
+
+    # Float32 reference: unnormalised softmax numerator times V, and row 0's LSE.
+    qf = q[:, :, 0].float()
+    kf = k[:, :, 0].float()
+    scale = 1.0 / math.sqrt(hd)
+    attn_max = (qf @ kf.T * scale).max(dim=-1, keepdim=True)[0]
+    attn_exp = torch.exp(qf @ kf.T * scale - attn_max)
+    attn_sum = attn_exp.sum(dim=-1, keepdim=True)
+    ref_unnorm = attn_exp @ v.float()
+    ref_lse = (torch.log(attn_sum.squeeze(-1)) + attn_max.squeeze(-1))[0].item()
+
+    lse_tensor = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
+    kernel = FmhaKernel(head_dim=hd, s_k=n, use_smem_p=False)
+    pv_n_tile = kernel.pv_n_tile
+    n_pv_tiles = kernel.n_pv_tiles
+    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+
+    print(f'hd={hd}, pv_n_tile={pv_n_tile}, n_pv_tiles={n_pv_tiles}, n_k_sub_tiles={kernel.n_k_sub_tiles}, k_tile={kernel.k_tile}', flush=True)
+    print(f'Compiling first PV tile...', flush=True)
+
+    # Only compile the first PV tile to isolate compilation issues
+    v_tile = v[:, 0:pv_n_tile].contiguous()
+    v_kernel = v_tile.unsqueeze(-1)
+    c_tile = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
+
+    mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
+    mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
+    mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
+    mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile))
+    mLSE = ct.from_dlpack(lse_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tensor))
+
+    import time
+    t0 = time.time()
+    compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, mLSE)
+    t1 = time.time()
+    print(f'Compilation took {t1-t0:.1f}s', flush=True)
+
+    # Run all PV tiles
+    lse_val = None
+    c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda')
+    for nt in range(n_pv_tiles):
+        v_start = nt * pv_n_tile
+        v_end = v_start + pv_n_tile
+        v_tile = v[:, v_start:v_end].contiguous()
+        v_kernel = v_tile.unsqueeze(-1)
+        c_tile = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
+        lse_tensor.zero_()
+
+        mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
+        mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
+        mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
+        mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile))
+        mLSE = ct.from_dlpack(lse_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tensor))
+
+        compiled(mQ, mK, mV, mC, stream, mLSE)
+        torch.cuda.synchronize()
+        print(f' PV tile {nt}: done', flush=True)
+
+        c[:, v_start:v_end, :] = c_tile
+        if nt == 0:
+            lse_val = lse_tensor[0, 0, 0].item()
+
+    out_unnorm = c[:, :, 0].float()
+    cos_unnorm = torch.nn.functional.cosine_similarity(
+        out_unnorm.flatten().unsqueeze(0), ref_unnorm.flatten().unsqueeze(0)
+    ).item()
+    lse_err = abs(lse_val - ref_lse) if lse_val is not None else float('inf')
+
+    status = "PASS" if cos_unnorm >= 0.99 else "FAIL"
+    print(f'hd={hd}: cos_unnorm {cos_unnorm:.6f} lse_err {lse_err:.6f} {status}')
+    # Fail loudly: without this the test passes under pytest even on "FAIL".
+    assert status == "PASS", f'hd={hd}: cos_unnorm={cos_unnorm:.6f}, expected >= 0.99'
+
+if __name__ == '__main__':
+    test()
diff --git a/tests/unit/test_d1_qk512.py b/tests/unit/test_d1_qk512.py
new file mode 100644
index 00000000..581187ee
--- /dev/null
+++ b/tests/unit/test_d1_qk512.py
@@ -0,0 +1,178 @@
+"""Minimal hd=512 test: ONLY QK GEMM, no softmax, no PV.
+Goal: isolate whether the compilation hang is from QK or softmax/PV."""
+import torch, math, time
+import cutlass
+import cutlass.cute as cute
+import cutlass.torch as ct
+import cuda.bindings.driver as cuda
+from cutlass import BFloat16, Float32
+from cutlass.cute.nvgpu import cpasync, tcgen05
+from cutlass.utils import LayoutEnum
+import cutlass.utils as utils
+import cutlass.pipeline as pipeline
+from cutlass import const_expr
+
+class QkOnly512:
+    """QK GEMM only: TMA-load Q and K in two K sub-tiles into one MMA tile.
+
+    The accumulator is never read back and s_out is never written, so this
+    only exercises compilation and launch of the QK stage.
+    """
+    def __init__(self):
+        """Set the fixed tile shapes, stage counts and warp roles."""
+        self.head_dim = 512
+        self.k_tile = 256
+        self.n_k_sub_tiles = 2
+        self.kv_stage = 1
+        self.q_stage = 1
+        self.q_dtype = BFloat16
+        self.qk_acc_dtype = Float32
+        self.cta_group = tcgen05.CtaGroup.ONE
+        self.cluster_shape_mn = (1, 1)
+        self.qk_mma_tiler = (128, 128, self.k_tile)
+        self.threads_per_cta = 192
+        self.mma_warp_id = 4
+        self.tma_warp_id = 5
+        self.epilogue_warp_id = (0,1,2,3)
+
+    def _setup(self, qk_mma):
+        """Derive the Q/K shared-memory layouts and the pipelines' tx byte counts."""
+        self.q_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.qk_mma_tiler, self.q_dtype, self.q_stage)
+        self.k_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.qk_mma_tiler, self.q_dtype, self.kv_stage)
+        cta = cute.size(qk_mma.thr_id.shape)
+        q_s = cute.slice_(self.q_smem_s,(None,None,None,0))
+        k_s = cute.slice_(self.k_smem_s,(None,None,None,0))
+        self.q_tx_bytes = cute.size_in_bytes(self.q_dtype, q_s) * cta
+        self.kv_tx_bytes = cute.size_in_bytes(self.q_dtype, k_s) * cta
+
+    @cute.jit
+    def __call__(self, q, k, s_out, stream):
+        # Builds the tiled MMA and the Q/K TMA atoms, then launches a single CTA.
+        self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
+        self.b_major = LayoutEnum.from_tensor(k).mma_major_mode()
+        qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, self.a_major, self.b_major, self.qk_acc_dtype, self.cta_group, (128,128), tcgen05.OperandSource.SMEM)
+        self._setup(qk_mma)
+        q_s = cute.slice_(self.q_smem_s,(None,None,None,0))
+        k_s = cute.slice_(self.k_smem_s,(None,None,None,0))
+        tma_q,mQ = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn,qk_mma.thr_id),q,q_s,self.qk_mma_tiler,qk_mma,self.cluster_shape_mn)
+        tma_k,mK = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,qk_mma.thr_id),k,k_s,self.qk_mma_tiler,qk_mma,self.cluster_shape_mn)
+        self._kernel(qk_mma, tma_q, mQ, tma_k, mK, s_out).launch(grid=(1,1,1),block=[self.threads_per_cta,1,1],stream=stream)
+
+    @cute.kernel
+    def _kernel(self, qk_mma, tma_q, mQ, tma_k, mK, s_out):
+        # Warp roles: tma_warp_id issues the TMA loads, mma_warp_id issues the
+        # MMAs, and the remaining warps only allocate and free TMEM.
+        warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
+        tidx,_,_ = cute.arch.thread_idx()
+        if warp_idx == self.tma_warp_id:
+            cpasync.prefetch_descriptor(tma_q); cpasync.prefetch_descriptor(tma_k)
+
+        @cute.struct
+        class SS:
+            q_bar: cute.struct.MemRange[cutlass.Int64, self.q_stage*2]
+            kv_bar: cute.struct.MemRange[cutlass.Int64, self.kv_stage*2]
+            holding: cutlass.Int32
+        smem = utils.SmemAllocator(); st = smem.allocate(SS)
+
+        qp,qc = pipeline.PipelineTmaUmma.create(barrier_storage=st.q_bar.data_ptr(),num_stages=self.q_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.q_tx_bytes,cta_layout_vmnk=cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)),defer_sync=True).make_participants()
+        kvp,kvc = pipeline.PipelineTmaUmma.create(barrier_storage=st.kv_bar.data_ptr(),num_stages=self.kv_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.kv_tx_bytes,cta_layout_vmnk=cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)),defer_sync=True).make_participants()
+
+        tmem_bar = pipeline.NamedBarrier(barrier_id=2,num_threads=32*len((self.mma_warp_id,*self.epilogue_warp_id)))
+        tmem = utils.TmemAllocator(st.holding.ptr,barrier_for_retrieve=tmem_bar,allocator_warp_id=self.epilogue_warp_id[0],is_two_cta=cute.size(qk_mma.thr_id.shape)==2)
+        pipeline.pipeline_init_arrive(cluster_shape_mn=cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)),is_relaxed=True)
+
+        sQ = smem.allocate_tensor(element_type=self.q_dtype,layout=self.q_smem_s.outer,byte_alignment=128,swizzle=self.q_smem_s.inner)
+        sK = smem.allocate_tensor(element_type=self.q_dtype,layout=self.k_smem_s.outer,byte_alignment=128,swizzle=self.k_smem_s.inner)
+
+        gQ = cute.local_tile(mQ,cute.slice_(self.qk_mma_tiler,(None,0,None)),(None,None,None))
+        gK = cute.local_tile(mK,cute.slice_(self.qk_mma_tiler,(0,None,None)),(None,None,None))
+
+        qk_thr = qk_mma.get_slice(0)
+        tCgQ = qk_thr.partition_A(gQ); tCgK = qk_thr.partition_B(gK)
+        a_lay = cute.make_layout((1,))
+        b_lay = cute.make_layout((1,))
+        tAsQ,tAgQ = cpasync.tma_partition(tma_q,0,a_lay,cute.group_modes(sQ,0,3),cute.group_modes(tCgQ,0,3))
+        tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3))
+        tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)]
+
+        tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
+        qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
+        tStS = qk_thr.make_fragment_C(qk_as)
+        tStS0 = cute.make_tensor(tStS.iterator, tStS.layout)
+
+        pipeline.pipeline_init_wait(cluster_shape_mn=cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)))
+
+        # ===== TMA LOAD warp =====
+        if warp_idx == self.tma_warp_id:
+            qp.reset()
+            kvp.reset()
+            # k_sub=0
+            qh0 = qp.acquire_and_advance()
+            cute.copy(tma_q, tAgQ[(None, cutlass.Int32(0))], tAsQ[(None, qh0.index)], tma_bar_ptr=qh0.barrier)
+            kvh0 = kvp.acquire_and_advance()
+            cute.copy(tma_k, tBgK[(None, cutlass.Int32(0))], tBsK[(None, kvh0.index)], tma_bar_ptr=kvh0.barrier)
+            # k_sub=1
+            qh1 = qp.acquire_and_advance()
+            cute.copy(tma_q, tAgQ[(None, cutlass.Int32(1))], tAsQ[(None, qh1.index)], tma_bar_ptr=qh1.barrier)
+            kvh1 = kvp.acquire_and_advance()
+            cute.copy(tma_k, tBgK[(None, cutlass.Int32(1))], tBsK[(None, kvh1.index)], tma_bar_ptr=kvh1.barrier)
+            qp.tail()
+            kvp.tail()
+
+        # ===== MMA warp =====
+        if warp_idx == self.mma_warp_id:
+            tmem.wait_for_alloc()
+            # NOTE(review): each Q stage is released before the MMAs that read
+            # sQ are issued; with q_stage=1 confirm the k_sub=1 load cannot race.
+            # k_sub=0
+            qh0 = qc.wait_and_advance(); qh0.release()
+            kvh0 = kvc.wait_and_advance()
+            qk_mma.set(tcgen05.Field.ACCUMULATE, False)
+            for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True):
+                cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,0)], tCrK[(None,None,kb,kvh0.index)], tStS0)
+                qk_mma.set(tcgen05.Field.ACCUMULATE, True)
+            kvh0.release()
+            # k_sub=1
+            qh1 = qc.wait_and_advance(); qh1.release()
+            kvh1 = kvc.wait_and_advance()
+            for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True):
+                cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,0)], tCrK[(None,None,kb,kvh1.index)], tStS0)
+                qk_mma.set(tcgen05.Field.ACCUMULATE, True)
+            kvh1.release()
+            cute.arch.fence_view_async_tmem_store()
+
+        # Epilogue warps just allocate/free TMEM
+        if warp_idx < self.mma_warp_id:
+            tmem.allocate(64)
+            tmem.wait_for_alloc()
+            tmem.relinquish_alloc_permit()
+            tmem.free(tmem.retrieve_ptr(self.qk_acc_dtype))
+
+
+def test():
+    """Compile and launch QkOnly512 once on random bf16 Q/K; print on success."""
+    torch.manual_seed(42)
+    hd, n, m = 512, 128, 128
+
+    q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
+    k = torch.randn(n, hd, 1, dtype=torch.bfloat16, device='cuda')
+    s_out = torch.zeros(1, dtype=torch.float32, device='cuda')  # dummy
+
+    mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
+    mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
+    mS = ct.from_dlpack(s_out).mark_layout_dynamic(leading_dim=ct.get_leading_dim(s_out))
+    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
+
+    kernel = QkOnly512()
+    print('Compiling QK-only hd=512...', flush=True)
+    t0 = time.time()
+    compiled = cute.compile(kernel, mQ, mK, mS, stream)
+    t1 = time.time()
+    print(f'Compilation took {t1-t0:.1f}s', flush=True)
+
+    compiled(mQ, mK, mS, stream)
+    torch.cuda.synchronize()
+    print('QK-only hd=512: SUCCESS')
+
+if __name__ == '__main__':
+    test()