D1.4: Add hd=512 QK-only and standalone test for compilation debugging

This commit is contained in:
2026-05-24 14:19:26 +00:00
parent 592873b560
commit 625837fd44
2 changed files with 251 additions and 0 deletions

View File

@@ -0,0 +1,87 @@
"""D1 test: HEAD_DIM=512 only (faster iteration on compilation issues)."""
import torch, math, sys
import cutlass.cute as cute
import cutlass.torch as ct
import cuda.bindings.driver as cuda
from dsv4.kernels.attention.fmha import FmhaKernel
def test():
torch.manual_seed(42)
hd, n = 512, 128
m = 128
q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
k = torch.randn(n, hd, 1, dtype=torch.bfloat16, device='cuda')
v = torch.randn(n, hd, dtype=torch.bfloat16, device='cuda')
qf = q[:, :, 0].float()
kf = k[:, :, 0].float()
scale = 1.0 / math.sqrt(hd)
attn_max = (qf @ kf.T * scale).max(dim=-1, keepdim=True)[0]
attn_exp = torch.exp(qf @ kf.T * scale - attn_max)
attn_sum = attn_exp.sum(dim=-1, keepdim=True)
ref_unnorm = attn_exp @ v.float()
ref_lse = (torch.log(attn_sum.squeeze(-1)) + attn_max.squeeze(-1))[0].item()
lse_tensor = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
kernel = FmhaKernel(head_dim=hd, s_k=n, use_smem_p=False)
pv_n_tile = kernel.pv_n_tile
n_pv_tiles = kernel.n_pv_tiles
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
print(f'hd={hd}, pv_n_tile={pv_n_tile}, n_pv_tiles={n_pv_tiles}, n_k_sub_tiles={kernel.n_k_sub_tiles}, k_tile={kernel.k_tile}', flush=True)
print(f'Compiling first PV tile...', flush=True)
# Only compile the first PV tile to isolate compilation issues
v_tile = v[:, 0:pv_n_tile].contiguous()
v_kernel = v_tile.unsqueeze(-1)
c_tile = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile))
mLSE = ct.from_dlpack(lse_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tensor))
import time
t0 = time.time()
compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, mLSE)
t1 = time.time()
print(f'Compilation took {t1-t0:.1f}s', flush=True)
# Run all PV tiles
lse_val = None
c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda')
for nt in range(n_pv_tiles):
v_start = nt * pv_n_tile
v_end = v_start + pv_n_tile
v_tile = v[:, v_start:v_end].contiguous()
v_kernel = v_tile.unsqueeze(-1)
c_tile = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
lse_tensor.zero_()
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_kernel))
mC = ct.from_dlpack(c_tile).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c_tile))
mLSE = ct.from_dlpack(lse_tensor).mark_layout_dynamic(leading_dim=ct.get_leading_dim(lse_tensor))
compiled(mQ, mK, mV, mC, stream, mLSE)
torch.cuda.synchronize()
print(f' PV tile {nt}: done', flush=True)
c[:, v_start:v_end, :] = c_tile
if nt == 0:
lse_val = lse_tensor[0, 0, 0].item()
out_unnorm = c[:, :, 0].float()
cos_unnorm = torch.nn.functional.cosine_similarity(
out_unnorm.flatten().unsqueeze(0), ref_unnorm.flatten().unsqueeze(0)
).item()
lse_err = abs(lse_val - ref_lse) if lse_val is not None else float('inf')
status = "PASS" if cos_unnorm >= 0.99 else "FAIL"
print(f'hd={hd}: cos_unnorm {cos_unnorm:.6f} lse_err {lse_err:.6f} {status}')
if __name__ == '__main__':
test()

164
tests/unit/test_d1_qk512.py Normal file
View File

@@ -0,0 +1,164 @@
"""Minimal hd=512 test: ONLY QK GEMM, no softmax, no PV.
Goal: isolate whether the compilation hang is from QK or softmax/PV."""
import torch, math, time
import cutlass.cute as cute
import cutlass.torch as ct
import cuda.bindings.driver as cuda
from cutlass import BFloat16, Float32
from cutlass.cute.nvgpu import tcgen05
from cutlass.utils import LayoutEnum
import cutlass.utils as utils
import cutlass.pipeline as pipeline
from cutlass import const_expr
class QkOnly512:
def __init__(self):
self.head_dim = 512
self.k_tile = 256
self.n_k_sub_tiles = 2
self.kv_stage = 1
self.q_stage = 1
self.q_dtype = BFloat16
self.qk_acc_dtype = Float32
self.cta_group = tcgen05.CtaGroup.ONE
self.cluster_shape_mn = (1, 1)
self.qk_mma_tiler = (128, 128, self.k_tile)
self.threads_per_cta = 192
self.mma_warp_id = 4
self.tma_warp_id = 5
self.epilogue_warp_id = (0,1,2,3)
def _setup(self, qk_mma):
self.q_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.qk_mma_tiler, self.q_dtype, self.q_stage)
self.k_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.qk_mma_tiler, self.q_dtype, self.kv_stage)
cta = cute.size(qk_mma.thr_id.shape)
q_s = cute.slice_(self.q_smem_s,(None,None,None,0))
k_s = cute.slice_(self.k_smem_s,(None,None,None,0))
self.q_tx_bytes = cute.size_in_bytes(self.q_dtype, q_s) * cta
self.kv_tx_bytes = cute.size_in_bytes(self.q_dtype, k_s) * cta
@cute.jit
def __call__(self, q, k, s_out, stream):
self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
self.b_major = LayoutEnum.from_tensor(k).mma_major_mode()
qk_mma = utils.sm100.make_trivial_tiled_mma(self.q_dtype, self.q_dtype, self.a_major, self.b_major, self.qk_acc_dtype, self.cta_group, (128,128), tcgen05.OperandSource.SMEM)
self._setup(qk_mma)
q_s = cute.slice_(self.q_smem_s,(None,None,None,0))
k_s = cute.slice_(self.k_smem_s,(None,None,None,0))
tma_q,mQ = cute.nvgpu.make_tiled_tma_atom_A(utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn,qk_mma.thr_id),q,q_s,self.qk_mma_tiler,qk_mma,self.cluster_shape_mn)
tma_k,mK = cute.nvgpu.make_tiled_tma_atom_B(utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn,qk_mma.thr_id),k,k_s,self.qk_mma_tiler,qk_mma,self.cluster_shape_mn)
self._kernel(qk_mma, tma_q, mQ, tma_k, mK, s_out).launch(grid=(1,1,1),block=[self.threads_per_cta,1,1],stream=stream)
@cute.kernel
def _kernel(self, qk_mma, tma_q, mQ, tma_k, mK, s_out):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx,_,_ = cute.arch.thread_idx()
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_q); cpasync.prefetch_descriptor(tma_k)
@cute.struct
class SS:
q_bar: cute.struct.MemRange[cutlass.Int64, self.q_stage*2]
kv_bar: cute.struct.MemRange[cutlass.Int64, self.kv_stage*2]
holding: cutlass.Int32
smem = utils.SmemAllocator(); st = smem.allocate(SS)
qp,qc = pipeline.PipelineTmaUmma.create(barrier_storage=st.q_bar.data_ptr(),num_stages=self.q_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.q_tx_bytes,cta_layout_vmnk=cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)),defer_sync=True).make_participants()
kvp,kvc = pipeline.PipelineTmaUmma.create(barrier_storage=st.kv_bar.data_ptr(),num_stages=self.kv_stage,producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread,1),tx_count=self.kv_tx_bytes,cta_layout_vmnk=cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)),defer_sync=True).make_participants()
tmem_bar = pipeline.NamedBarrier(barrier_id=2,num_threads=32*len((self.mma_warp_id,*self.epilogue_warp_id)))
tmem = utils.TmemAllocator(st.holding.ptr,barrier_for_retrieve=tmem_bar,allocator_warp_id=self.epilogue_warp_id[0],is_two_cta=cute.size(qk_mma.thr_id.shape)==2)
pipeline.pipeline_init_arrive(cluster_shape_mn=cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)),is_relaxed=True)
sQ = smem.allocate_tensor(element_type=self.q_dtype,layout=self.q_smem_s.outer,byte_alignment=128,swizzle=self.q_smem_s.inner)
sK = smem.allocate_tensor(element_type=self.q_dtype,layout=self.k_smem_s.outer,byte_alignment=128,swizzle=self.k_smem_s.inner)
gQ = cute.local_tile(mQ,cute.slice_(self.qk_mma_tiler,(None,0,None)),(None,None,None))
gK = cute.local_tile(mK,cute.slice_(self.qk_mma_tiler,(0,None,None)),(None,None,None))
qk_thr = qk_mma.get_slice(0)
tCgQ = qk_thr.partition_A(gQ); tCgK = qk_thr.partition_B(gK)
a_lay = cute.make_layout((1,))
b_lay = cute.make_layout((1,))
tAsQ,tAgQ = cpasync.tma_partition(tma_q,0,a_lay,cute.group_modes(sQ,0,3),cute.group_modes(tCgQ,0,3))
tBsK,tBgK = cpasync.tma_partition(tma_k,0,b_lay,cute.group_modes(sK,0,3),cute.group_modes(tCgK,0,3))
tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)]
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_as)
tStS0 = cute.make_tensor(tStS.iterator, tStS.layout)
pipeline.pipeline_init_wait(cluster_shape_mn=cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,)))
# ===== TMA LOAD warp =====
if warp_idx == self.tma_warp_id:
qp.reset()
kvp.reset()
# k_sub=0
qh0 = qp.acquire_and_advance()
cute.copy(tma_q, tAgQ[(None, cutlass.Int32(0))], tAsQ[(None, qh0.index)], tma_bar_ptr=qh0.barrier)
kvh0 = kvp.acquire_and_advance()
cute.copy(tma_k, tBgK[(None, cutlass.Int32(0))], tBsK[(None, kvh0.index)], tma_bar_ptr=kvh0.barrier)
# k_sub=1
qh1 = qp.acquire_and_advance()
cute.copy(tma_q, tAgQ[(None, cutlass.Int32(1))], tAsQ[(None, qh1.index)], tma_bar_ptr=qh1.barrier)
kvh1 = kvp.acquire_and_advance()
cute.copy(tma_k, tBgK[(None, cutlass.Int32(1))], tBsK[(None, kvh1.index)], tma_bar_ptr=kvh1.barrier)
qp.tail()
kvp.tail()
# ===== MMA warp =====
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
# k_sub=0
qh0 = qc.wait_and_advance(); qh0.release()
kvh0 = kvc.wait_and_advance()
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,0)], tCrK[(None,None,kb,kvh0.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
kvh0.release()
# k_sub=1
qh1 = qc.wait_and_advance(); qh1.release()
kvh1 = kvc.wait_and_advance()
for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,0)], tCrK[(None,None,kb,kvh1.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
kvh1.release()
cute.arch.fence_view_async_tmem_store()
# Epilogue warps just allocate/free TMEM
if warp_idx < self.mma_warp_id:
tmem.allocate(64)
tmem.wait_for_alloc()
tmem.relinquish_alloc_permit()
tmem.free(tmem.retrieve_ptr(self.qk_acc_dtype))
def test():
torch.manual_seed(42)
hd, n, m = 512, 128, 128
q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
k = torch.randn(n, hd, 1, dtype=torch.bfloat16, device='cuda')
s_out = torch.zeros(1, dtype=torch.float32, device='cuda') # dummy
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
mS = ct.from_dlpack(s_out).mark_layout_dynamic(leading_dim=ct.get_leading_dim(s_out))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = QkOnly512()
print('Compiling QK-only hd=512...', flush=True)
t0 = time.time()
compiled = cute.compile(kernel, mQ, mK, mS, stream)
t1 = time.time()
print(f'Compilation took {t1-t0:.1f}s', flush=True)
compiled(mQ, mK, mS, stream)
torch.cuda.synchronize()
print('QK-only hd=512: SUCCESS')
if __name__ == '__main__':
test()