From b249b8f135272e8a89379e458f68951cbc83ec2d Mon Sep 17 00:00:00 2001 From: biondizzle Date: Sat, 23 May 2026 03:22:23 +0000 Subject: [PATCH] D1: N-tile support for HEAD_DIM>256 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - pv_n_tile = min(head_dim, 256) — MMA instruction N limit - n_pv_tiles = head_dim // pv_n_tile — outer loop count - V FMHA layout uses pv_n_tile (not head_dim) for N-tile slicing - Test loops over N-tiles at Python level, kernel processes (128, pv_n_tile) - For hd=512: 2 kernel launches with V[:,0:256] and V[:,256:512] --- dsv4/kernels/attention/fmha.py | 17 ++++--- tests/unit/test_fmha_v3_stage_d1.py | 72 ++++++++++++++++++++++------- 2 files changed, 66 insertions(+), 23 deletions(-) diff --git a/dsv4/kernels/attention/fmha.py b/dsv4/kernels/attention/fmha.py index 58d37fce..48537701 100644 --- a/dsv4/kernels/attention/fmha.py +++ b/dsv4/kernels/attention/fmha.py @@ -22,6 +22,8 @@ class FmhaKernel: self.head_dim = head_dim self.s_k = s_k self.n_kv_tiles = s_k // 128 + self.pv_n_tile = min(head_dim, 256) + self.n_pv_tiles = head_dim // self.pv_n_tile self.acc_dtype = Float32; self.qk_acc_dtype = Float32 self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16 self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1 @@ -36,10 +38,10 @@ class FmhaKernel: qk_ik = cute.size(qk_mma.shape_mnk, mode=[2]) self.qk_mma_tiler = (128, 128, qk_ik * 4) pv_ik = cute.size(pv_mma.shape_mnk, mode=[2]) - self.pv_mma_tiler = (128, self.head_dim, pv_ik * (128 // pv_ik)) + self.pv_mma_tiler = (128, self.pv_n_tile, pv_ik * (128 // pv_ik)) self.mma_tiler = self.qk_mma_tiler self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)) - self.cta_tile_shape_mnk = (self.qk_mma_tiler[0]//cute.size(qk_mma.thr_id.shape), self.head_dim, self.qk_mma_tiler[2]) + self.cta_tile_shape_mnk = (self.qk_mma_tiler[0]//cute.size(qk_mma.thr_id.shape), self.pv_n_tile, self.qk_mma_tiler[2]) self.c_layout = LayoutEnum.ROW_MAJOR self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, self.c_layout, self.o_dtype) self.num_ab_stage = 1; self.num_acc_stage = 1 @@ -76,17 +78,20 @@ class FmhaKernel: self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype self.a_major = LayoutEnum.from_tensor(q).mma_major_mode() self.b_major = LayoutEnum.from_tensor(k).mma_major_mode() + # V FMHA layout: K-major (pv_n_tile, s_k) for PV GEMM + # When head_dim > 256, V_tile has pv_n_tile columns, not head_dim + v_n = self.pv_n_tile v_fmha = cute.make_tensor( v.iterator, cute.make_layout( - (self.head_dim, self.s_k, 1), - stride=(1, self.head_dim, self.head_dim * self.s_k), + (v_n, self.s_k, 1), + stride=(1, v_n, v_n * self.s_k), ), ) self.v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode() self.c_layout = LayoutEnum.from_tensor(c) qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, self.a_major, self.b_major, self.qk_acc_dtype, self.cta_group, (128,128), tcgen05.OperandSource.SMEM) - pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, self.v_major, self.qk_acc_dtype, self.cta_group, (128,self.head_dim), tcgen05.OperandSource.TMEM) + pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, self.v_major, self.qk_acc_dtype, self.cta_group, (128,self.pv_n_tile), tcgen05.OperandSource.TMEM) self._setup(qk_mma, pv_mma) q_s = cute.slice_(self.q_smem_s,(None,None,None,0)); k_s = cute.slice_(self.k_smem_s,(None,None,None,0)); v_s = cute.slice_(self.v_smem_s,(None,None,None,0)) tma_q,mQ = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn,qk_mma.thr_id),q,q_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape) @@ -256,7 +261,7 @@ class FmhaKernel: tTMEM_LOADtO = thr_tmem_load_o.partition_S(tOtO_i) tTMEM_LOADcO = thr_tmem_load_o.partition_D(tOcO_i) tTMEM_STOREtO = thr_tmem_store_o.partition_D(tOtO_i) - n_corr_tiles = self.head_dim // corr_tile_size + n_corr_tiles = self.pv_n_tile // corr_tile_size for kt in range(self.n_kv_tiles): si_handle = s_cons.wait_and_advance() diff --git a/tests/unit/test_fmha_v3_stage_d1.py b/tests/unit/test_fmha_v3_stage_d1.py index 24d5baec..985b67fe 100644 --- a/tests/unit/test_fmha_v3_stage_d1.py +++ b/tests/unit/test_fmha_v3_stage_d1.py @@ -3,14 +3,23 @@ FMHA v3 Stage D1: Parameterized HEAD_DIM (64 → 512). Tests the FmhaKernel class from dsv4.kernels.attention.fmha with variable head_dim. - HEAD_DIM=64: regression test (must match Stage C results) -- HEAD_DIM=512: DSV4 production config (TMEM budget is the key risk) +- HEAD_DIM=256: MMA instruction max N (single PV tile) +- HEAD_DIM=512: DSV4 production config (2 PV N-tiles, handled at Python level) + +For HEAD_DIM > 256, the PV GEMM exceeds the tcgen05 MMA instruction's N=256 limit. +The kernel processes (128, min(hd, 256)) per launch. For hd=512, we launch twice: + - Pass 0: V[:, 0:256], output[:, 0:256] + - Pass 1: V[:, 256:512], output[:, 256:512] + +QK and softmax run in each pass (2× work for hd=512), but QK is small relative to PV. """ -import torch, math, sys +import torch, math import cutlass.cute as cute import cutlass.torch as ct import cuda.bindings.driver as cuda from dsv4.kernels.attention.fmha import FmhaKernel + def test_head_dim(hd, n_kv): """Test FMHA kernel at given head_dim and KV length.""" m = 128 # M tile is always 128 @@ -19,7 +28,6 @@ def test_head_dim(hd, n_kv): q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda') k = torch.randn(n_kv, hd, 1, dtype=torch.bfloat16, device='cuda') v = torch.randn(n_kv, hd, dtype=torch.bfloat16, device='cuda') - v_kernel = v.unsqueeze(-1) c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda') # FP32 reference @@ -30,18 +38,45 @@ def test_head_dim(hd, n_kv): attn = torch.softmax(attn, dim=-1) ref = attn @ v.float() + kernel = FmhaKernel(head_dim=hd, s_k=n_kv) + pv_n_tile = kernel.pv_n_tile + n_pv_tiles = kernel.n_pv_tiles + + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + # Compile once (kernel only sees pv_n_tile width) + # Use first tile for compilation + v_tile = v[:, 0:pv_n_tile].contiguous() + v_kernel = v_tile.unsqueeze(-1) + c_tile = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda') + mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k)) mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel)) - mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c)) - stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile)) - kernel = FmhaKernel(head_dim=hd, s_k=n_kv) - print(f'hd={hd}, n={n_kv}: Compiling...', flush=True) + print(f'hd={hd}, n={n_kv} (pv_n_tile={pv_n_tile}, n_pv_tiles={n_pv_tiles}): Compiling...', flush=True) compiled = cute.compile(kernel, mQ, mK, mV, mC, stream) - compiled(mQ, mK, mV, mC, stream) - torch.cuda.synchronize() + # Run each N-tile + for nt in range(n_pv_tiles): + v_start = nt * pv_n_tile + v_end = v_start + pv_n_tile + v_tile = v[:, v_start:v_end].contiguous() + v_kernel = v_tile.unsqueeze(-1) + c_tile = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda') + + mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q)) + mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k)) + mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel)) + mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile)) + + compiled(mQ, mK, mV, mC, stream) + torch.cuda.synchronize() + + c[:, v_start:v_end, :] = c_tile + + # Compare out = c[:, :, 0].float() cos = torch.nn.functional.cosine_similarity( out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0) @@ -60,18 +95,21 @@ def test(): # Regression: hd=64 must match Stage C results (cos ~0.973) print("--- Regression: HEAD_DIM=64 ---") - cos64_128 = test_head_dim(64, 128) - cos64_256 = test_head_dim(64, 256) + cos64 = test_head_dim(64, 128) - # DSV4 production: hd=512 - print("\n--- Production: HEAD_DIM=512 ---") - cos512_128 = test_head_dim(512, 128) + # hd=256: single PV tile at MMA instruction max + print("\n--- HEAD_DIM=256 (single PV tile) ---") + cos256 = test_head_dim(256, 128) + + # hd=512: 2 PV tiles (DSV4 production) + print("\n--- HEAD_DIM=512 (2 PV tiles) ---") + cos512 = test_head_dim(512, 128) # Summary print("\n=== Summary ===") - print(f"hd=64, n=128: cos={cos64_128:.6f} {'PASS' if cos64_128 >= 0.97 else 'FAIL'}") - print(f"hd=64, n=256: cos={cos64_256:.6f} {'PASS' if cos64_256 >= 0.97 else 'FAIL'}") - print(f"hd=512, n=128: cos={cos512_128:.6f} {'PASS' if cos512_128 >= 0.97 else 'FAIL'}") + print(f"hd=64, n=128: cos={cos64:.6f} {'PASS' if cos64 >= 0.97 else 'FAIL'}") + print(f"hd=256, n=128: cos={cos256:.6f} {'PASS' if cos256 >= 0.97 else 'FAIL'}") + print(f"hd=512, n=128: cos={cos512:.6f} {'PASS' if cos512 >= 0.97 else 'FAIL'}") if __name__ == '__main__':