D1: N-tile support for HEAD_DIM>256
- pv_n_tile = min(head_dim, 256) — MMA instruction N limit - n_pv_tiles = head_dim // pv_n_tile — outer loop count - V FMHA layout uses pv_n_tile (not head_dim) for N-tile slicing - Test loops over N-tiles at Python level, kernel processes (128, pv_n_tile) - For hd=512: 2 kernel launches with V[:,0:256] and V[:,256:512]
This commit is contained in:
@@ -22,6 +22,8 @@ class FmhaKernel:
|
||||
self.head_dim = head_dim
|
||||
self.s_k = s_k
|
||||
self.n_kv_tiles = s_k // 128
|
||||
self.pv_n_tile = min(head_dim, 256)
|
||||
self.n_pv_tiles = head_dim // self.pv_n_tile
|
||||
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
|
||||
self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16
|
||||
self.use_2cta_instrs = False; self.epilog_sync_bar_id = 1
|
||||
@@ -36,10 +38,10 @@ class FmhaKernel:
|
||||
qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
self.qk_mma_tiler = (128, 128, qk_ik * 4)
|
||||
pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
|
||||
self.pv_mma_tiler = (128, self.head_dim, pv_ik * (128 // pv_ik))
|
||||
self.pv_mma_tiler = (128, self.pv_n_tile, pv_ik * (128 // pv_ik))
|
||||
self.mma_tiler = self.qk_mma_tiler
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
|
||||
self.cta_tile_shape_mnk = (self.qk_mma_tiler[0]//cute.size(qk_mma.thr_id.shape), self.head_dim, self.qk_mma_tiler[2])
|
||||
self.cta_tile_shape_mnk = (self.qk_mma_tiler[0]//cute.size(qk_mma.thr_id.shape), self.pv_n_tile, self.qk_mma_tiler[2])
|
||||
self.c_layout = LayoutEnum.ROW_MAJOR
|
||||
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(self.cta_tile_shape_mnk, False, self.c_layout, self.o_dtype)
|
||||
self.num_ab_stage = 1; self.num_acc_stage = 1
|
||||
@@ -76,17 +78,20 @@ class FmhaKernel:
|
||||
self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype
|
||||
self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
|
||||
self.b_major = LayoutEnum.from_tensor(k).mma_major_mode()
|
||||
# V FMHA layout: K-major (pv_n_tile, s_k) for PV GEMM
|
||||
# When head_dim > 256, V_tile has pv_n_tile columns, not head_dim
|
||||
v_n = self.pv_n_tile
|
||||
v_fmha = cute.make_tensor(
|
||||
v.iterator,
|
||||
cute.make_layout(
|
||||
(self.head_dim, self.s_k, 1),
|
||||
stride=(1, self.head_dim, self.head_dim * self.s_k),
|
||||
(v_n, self.s_k, 1),
|
||||
stride=(1, v_n, v_n * self.s_k),
|
||||
),
|
||||
)
|
||||
self.v_major = LayoutEnum.from_tensor(v_fmha).mma_major_mode()
|
||||
self.c_layout = LayoutEnum.from_tensor(c)
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, self.a_major, self.b_major, self.qk_acc_dtype, self.cta_group, (128,128), tcgen05.OperandSource.SMEM)
|
||||
pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, self.v_major, self.qk_acc_dtype, self.cta_group, (128,self.head_dim), tcgen05.OperandSource.TMEM)
|
||||
pv_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, self.v_major, self.qk_acc_dtype, self.cta_group, (128,self.pv_n_tile), tcgen05.OperandSource.TMEM)
|
||||
self._setup(qk_mma, pv_mma)
|
||||
q_s = cute.slice_(self.q_smem_s,(None,None,None,0)); k_s = cute.slice_(self.k_smem_s,(None,None,None,0)); v_s = cute.slice_(self.v_smem_s,(None,None,None,0))
|
||||
tma_q,mQ = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn,qk_mma.thr_id),q,q_s,self.qk_mma_tiler,qk_mma,self.cluster_layout_vmnk.shape)
|
||||
@@ -256,7 +261,7 @@ class FmhaKernel:
|
||||
tTMEM_LOADtO = thr_tmem_load_o.partition_S(tOtO_i)
|
||||
tTMEM_LOADcO = thr_tmem_load_o.partition_D(tOcO_i)
|
||||
tTMEM_STOREtO = thr_tmem_store_o.partition_D(tOtO_i)
|
||||
n_corr_tiles = self.head_dim // corr_tile_size
|
||||
n_corr_tiles = self.pv_n_tile // corr_tile_size
|
||||
|
||||
for kt in range(self.n_kv_tiles):
|
||||
si_handle = s_cons.wait_and_advance()
|
||||
|
||||
@@ -3,14 +3,23 @@ FMHA v3 Stage D1: Parameterized HEAD_DIM (64 → 512).
|
||||
|
||||
Tests the FmhaKernel class from dsv4.kernels.attention.fmha with variable head_dim.
|
||||
- HEAD_DIM=64: regression test (must match Stage C results)
|
||||
- HEAD_DIM=512: DSV4 production config (TMEM budget is the key risk)
|
||||
- HEAD_DIM=256: MMA instruction max N (single PV tile)
|
||||
- HEAD_DIM=512: DSV4 production config (2 PV N-tiles, handled at Python level)
|
||||
|
||||
For HEAD_DIM > 256, the PV GEMM exceeds the tcgen05 MMA instruction's N=256 limit.
|
||||
The kernel processes (128, min(hd, 256)) per launch. For hd=512, we launch twice:
|
||||
- Pass 0: V[:, 0:256], output[:, 0:256]
|
||||
- Pass 1: V[:, 256:512], output[:, 256:512]
|
||||
|
||||
QK and softmax run in each pass (2× work for hd=512), but QK is small relative to PV.
|
||||
"""
|
||||
import torch, math, sys
|
||||
import torch, math
|
||||
import cutlass.cute as cute
|
||||
import cutlass.torch as ct
|
||||
import cuda.bindings.driver as cuda
|
||||
from dsv4.kernels.attention.fmha import FmhaKernel
|
||||
|
||||
|
||||
def test_head_dim(hd, n_kv):
|
||||
"""Test FMHA kernel at given head_dim and KV length."""
|
||||
m = 128 # M tile is always 128
|
||||
@@ -19,7 +28,6 @@ def test_head_dim(hd, n_kv):
|
||||
q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
k = torch.randn(n_kv, hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
v = torch.randn(n_kv, hd, dtype=torch.bfloat16, device='cuda')
|
||||
v_kernel = v.unsqueeze(-1)
|
||||
c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
|
||||
# FP32 reference
|
||||
@@ -30,18 +38,45 @@ def test_head_dim(hd, n_kv):
|
||||
attn = torch.softmax(attn, dim=-1)
|
||||
ref = attn @ v.float()
|
||||
|
||||
kernel = FmhaKernel(head_dim=hd, s_k=n_kv)
|
||||
pv_n_tile = kernel.pv_n_tile
|
||||
n_pv_tiles = kernel.n_pv_tiles
|
||||
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
|
||||
# Compile once (kernel only sees pv_n_tile width)
|
||||
# Use first tile for compilation
|
||||
v_tile = v[:, 0:pv_n_tile].contiguous()
|
||||
v_kernel = v_tile.unsqueeze(-1)
|
||||
c_tile = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
|
||||
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
|
||||
mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
|
||||
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile))
|
||||
|
||||
kernel = FmhaKernel(head_dim=hd, s_k=n_kv)
|
||||
print(f'hd={hd}, n={n_kv}: Compiling...', flush=True)
|
||||
print(f'hd={hd}, n={n_kv} (pv_n_tile={pv_n_tile}, n_pv_tiles={n_pv_tiles}): Compiling...', flush=True)
|
||||
compiled = cute.compile(kernel, mQ, mK, mV, mC, stream)
|
||||
compiled(mQ, mK, mV, mC, stream)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Run each N-tile
|
||||
for nt in range(n_pv_tiles):
|
||||
v_start = nt * pv_n_tile
|
||||
v_end = v_start + pv_n_tile
|
||||
v_tile = v[:, v_start:v_end].contiguous()
|
||||
v_kernel = v_tile.unsqueeze(-1)
|
||||
c_tile = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
|
||||
|
||||
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
|
||||
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
|
||||
mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
|
||||
mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile))
|
||||
|
||||
compiled(mQ, mK, mV, mC, stream)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
c[:, v_start:v_end, :] = c_tile
|
||||
|
||||
# Compare
|
||||
out = c[:, :, 0].float()
|
||||
cos = torch.nn.functional.cosine_similarity(
|
||||
out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)
|
||||
@@ -60,18 +95,21 @@ def test():
|
||||
|
||||
# Regression: hd=64 must match Stage C results (cos ~0.973)
|
||||
print("--- Regression: HEAD_DIM=64 ---")
|
||||
cos64_128 = test_head_dim(64, 128)
|
||||
cos64_256 = test_head_dim(64, 256)
|
||||
cos64 = test_head_dim(64, 128)
|
||||
|
||||
# DSV4 production: hd=512
|
||||
print("\n--- Production: HEAD_DIM=512 ---")
|
||||
cos512_128 = test_head_dim(512, 128)
|
||||
# hd=256: single PV tile at MMA instruction max
|
||||
print("\n--- HEAD_DIM=256 (single PV tile) ---")
|
||||
cos256 = test_head_dim(256, 128)
|
||||
|
||||
# hd=512: 2 PV tiles (DSV4 production)
|
||||
print("\n--- HEAD_DIM=512 (2 PV tiles) ---")
|
||||
cos512 = test_head_dim(512, 128)
|
||||
|
||||
# Summary
|
||||
print("\n=== Summary ===")
|
||||
print(f"hd=64, n=128: cos={cos64_128:.6f} {'PASS' if cos64_128 >= 0.97 else 'FAIL'}")
|
||||
print(f"hd=64, n=256: cos={cos64_256:.6f} {'PASS' if cos64_256 >= 0.97 else 'FAIL'}")
|
||||
print(f"hd=512, n=128: cos={cos512_128:.6f} {'PASS' if cos512_128 >= 0.97 else 'FAIL'}")
|
||||
print(f"hd=64, n=128: cos={cos64:.6f} {'PASS' if cos64 >= 0.97 else 'FAIL'}")
|
||||
print(f"hd=256, n=128: cos={cos256:.6f} {'PASS' if cos256 >= 0.97 else 'FAIL'}")
|
||||
print(f"hd=512, n=128: cos={cos512:.6f} {'PASS' if cos512 >= 0.97 else 'FAIL'}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Reference in New Issue
Block a user