D1: N-tile support for HEAD_DIM>256

- pv_n_tile = min(head_dim, 256) — MMA instruction N limit
- n_pv_tiles = head_dim // pv_n_tile — outer loop count
- V FMHA layout uses pv_n_tile (not head_dim) for N-tile slicing
- Test loops over N-tiles at Python level, kernel processes (128, pv_n_tile)
- For hd=512: 2 kernel launches with V[:,0:256] and V[:,256:512]
This commit is contained in:
2026-05-23 03:22:23 +00:00
parent f2dced88a3
commit b249b8f135
2 changed files with 66 additions and 23 deletions

View File

@@ -22,6 +22,8 @@ class FmhaKernel:
self.head_dim = head_dim
self.s_k = s_k
self.n_kv_tiles = s_k // 128
self.pv_n_tile = min(head_dim, 256)
self.n_pv_tiles = head_dim // self.pv_n_tile
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16
self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1
@@ -36,10 +38,10 @@ class FmhaKernel:
qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (128, 128, qk_ik * 4)
pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
self.pv_mma_tiler = (128, self.head_dim, pv_ik * (128 // pv_ik))
self.pv_mma_tiler = (128, self.pv_n_tile, pv_ik * (128 // pv_ik))
self.mma_tiler = self.qk_mma_tiler
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.cta_tile_shape_mnk = (self.qk_mma_tiler[0]//cute.size(qk_mma.thr_id.shape), self.head_dim, self.qk_mma_tiler[2])
self.cta_tile_shape_mnk = (self.qk_mma_tiler[0]//cute.size(qk_mma.thr_id.shape), self.pv_n_tile, self.qk_mma_tiler[2])
self.c_layout = LayoutEnum.ROW_MAJOR
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, self.c_layout, self.o_dtype)
self.num_ab_stage = 1; self.num_acc_stage = 1
@@ -76,17 +78,20 @@ class FmhaKernel:
self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype
self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
self.b_major = LayoutEnum.from_tensor(k).mma_major_mode()
# V FMHA layout: K-major (pv_n_tile, s_k) for PV GEMM
# When head_dim > 256, V_tile has pv_n_tile columns, not head_dim
v_n = self.pv_n_tile
v_fmha = cute.make_tensor(
v.iterator,
cute.make_layout(
(self.head_dim, self.s_k, 1),
stride=(1, self.head_dim, self.head_dim * self.s_k),
(v_n, self.s_k, 1),
stride=(1, v_n, v_n * self.s_k),
),
)
self.v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode()
self.c_layout = LayoutEnum.from_tensor(c)
qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, self.a_major, self.b_major, self.qk_acc_dtype, self.cta_group, (128,128), tcgen05.OperandSource.SMEM)
pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, self.v_major, self.qk_acc_dtype, self.cta_group, (128,self.head_dim), tcgen05.OperandSource.TMEM)
pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, self.v_major, self.qk_acc_dtype, self.cta_group, (128,self.pv_n_tile), tcgen05.OperandSource.TMEM)
self._setup(qk_mma, pv_mma)
q_s = cute.slice_(self.q_smem_s,(None,None,None,0)); k_s = cute.slice_(self.k_smem_s,(None,None,None,0)); v_s = cute.slice_(self.v_smem_s,(None,None,None,0))
tma_q,mQ = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn,qk_mma.thr_id),q,q_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape)
@@ -256,7 +261,7 @@ class FmhaKernel:
tTMEM_LOADtO = thr_tmem_load_o.partition_S(tOtO_i)
tTMEM_LOADcO = thr_tmem_load_o.partition_D(tOcO_i)
tTMEM_STOREtO = thr_tmem_store_o.partition_D(tOtO_i)
n_corr_tiles = self.head_dim // corr_tile_size
n_corr_tiles = self.pv_n_tile // corr_tile_size
for kt in range(self.n_kv_tiles):
si_handle = s_cons.wait_and_advance()

View File

@@ -3,14 +3,23 @@ FMHA v3 Stage D1: Parameterized HEAD_DIM (64 → 512).
Tests the FmhaKernel class from dsv4.kernels.attention.fmha with variable head_dim.
- HEAD_DIM=64: regression test (must match Stage C results)
- HEAD_DIM=512: DSV4 production config (TMEM budget is the key risk)
- HEAD_DIM=256: MMA instruction max N (single PV tile)
- HEAD_DIM=512: DSV4 production config (2 PV N-tiles, handled at Python level)
For HEAD_DIM > 256, the PV GEMM exceeds the tcgen05 MMA instruction's N=256 limit.
The kernel processes (128, min(hd, 256)) per launch. For hd=512, we launch twice:
- Pass 0: V[:, 0:256], output[:, 0:256]
- Pass 1: V[:, 256:512], output[:, 256:512]
QK and softmax run in each pass (2× work for hd=512), but QK is small relative to PV.
"""
import torch, math, sys
import torch, math
import cutlass.cute as cute
import cutlass.torch as ct
import cuda.bindings.driver as cuda
from dsv4.kernels.attention.fmha import FmhaKernel
def test_head_dim(hd, n_kv):
"""Test FMHA kernel at given head_dim and KV length."""
m = 128 # M tile is always 128
@@ -19,7 +28,6 @@ def test_head_dim(hd, n_kv):
q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
k = torch.randn(n_kv, hd, 1, dtype=torch.bfloat16, device='cuda')
v = torch.randn(n_kv, hd, dtype=torch.bfloat16, device='cuda')
v_kernel = v.unsqueeze(-1)
c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda')
# FP32 reference
@@ -30,18 +38,45 @@ def test_head_dim(hd, n_kv):
attn = torch.softmax(attn, dim=-1)
ref = attn @ v.float()
kernel = FmhaKernel(head_dim=hd, s_k=n_kv)
pv_n_tile = kernel.pv_n_tile
n_pv_tiles = kernel.n_pv_tiles
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
# Compile once (kernel only sees pv_n_tile width)
# Use first tile for compilation
v_tile = v[:, 0:pv_n_tile].contiguous()
v_kernel = v_tile.unsqueeze(-1)
c_tile = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile))
kernel = FmhaKernel(head_dim=hd, s_k=n_kv)
print(f'hd={hd}, n={n_kv}: Compiling...', flush=True)
print(f'hd={hd}, n={n_kv} (pv_n_tile={pv_n_tile}, n_pv_tiles={n_pv_tiles}): Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mV, mC, stream)
compiled(mQ, mK, mV, mC, stream)
torch.cuda.synchronize()
# Run each N-tile
for nt in range(n_pv_tiles):
v_start = nt * pv_n_tile
v_end = v_start + pv_n_tile
v_tile = v[:, v_start:v_end].contiguous()
v_kernel = v_tile.unsqueeze(-1)
c_tile = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile))
compiled(mQ, mK, mV, mC, stream)
torch.cuda.synchronize()
c[:, v_start:v_end, :] = c_tile
# Compare
out = c[:, :, 0].float()
cos = torch.nn.functional.cosine_similarity(
out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)
@@ -60,18 +95,21 @@ def test():
# Regression: hd=64 must match Stage C results (cos ~0.973)
print("--- Regression: HEAD_DIM=64 ---")
cos64_128 = test_head_dim(64, 128)
cos64_256 = test_head_dim(64, 256)
cos64 = test_head_dim(64, 128)
# DSV4 production: hd=512
print("\n--- Production: HEAD_DIM=512 ---")
cos512_128 = test_head_dim(512, 128)
# hd=256: single PV tile at MMA instruction max
print("\n--- HEAD_DIM=256 (single PV tile) ---")
cos256 = test_head_dim(256, 128)
# hd=512: 2 PV tiles (DSV4 production)
print("\n--- HEAD_DIM=512 (2 PV tiles) ---")
cos512 = test_head_dim(512, 128)
# Summary
print("\n=== Summary ===")
print(f"hd=64, n=128: cos={cos64_128:.6f} {'PASS' if cos64_128 >= 0.97 else 'FAIL'}")
print(f"hd=64, n=256: cos={cos64_256:.6f} {'PASS' if cos64_256 >= 0.97 else 'FAIL'}")
print(f"hd=512, n=128: cos={cos512_128:.6f} {'PASS' if cos512_128 >= 0.97 else 'FAIL'}")
print(f"hd=64, n=128: cos={cos64:.6f} {'PASS' if cos64 >= 0.97 else 'FAIL'}")
print(f"hd=256, n=128: cos={cos256:.6f} {'PASS' if cos256 >= 0.97 else 'FAIL'}")
print(f"hd=512, n=128: cos={cos512:.6f} {'PASS' if cos512 >= 0.97 else 'FAIL'}")
if __name__ == '__main__':