D1.4: Add hd=512 QK-only and standalone test for compilation debugging
This commit is contained in:
87
tests/unit/test_d1_hd512_only.py
Normal file
87
tests/unit/test_d1_hd512_only.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""D1 test: HEAD_DIM=512 only (faster iteration on compilation issues)."""
|
||||
import torch, math, sys
|
||||
import cutlass.cute as cute
|
||||
import cutlass.torch as ct
|
||||
import cuda.bindings.driver as cuda
|
||||
from dsv4.kernels.attention.fmha import FmhaKernel
|
||||
|
||||
|
||||
def test():
    """Run the FMHA kernel at head_dim=512 and check it against PyTorch.

    The kernel is compiled once (timed, using the operands of the first PV
    tile) and then launched once per PV tile.  The per-tile outputs are
    stitched into the full un-normalized attention output, which is compared
    with an fp32 PyTorch reference via cosine similarity.  The log-sum-exp of
    query row 0, taken from the first tile's launch, is reported as an
    absolute error against the reference.

    Raises:
        AssertionError: if the cosine similarity is below 0.99.
    """
    import time  # local import: only needed to time the compilation

    def as_cute(t):
        """Wrap a torch tensor as a CuTe tensor with a dynamic layout."""
        return ct.from_dlpack(t).mark_layout_dynamic(
            leading_dim=ct.get_leading_dim(t)
        )

    torch.manual_seed(42)
    hd, n = 512, 128
    m = 128
    q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
    k = torch.randn(n, hd, 1, dtype=torch.bfloat16, device='cuda')
    v = torch.randn(n, hd, dtype=torch.bfloat16, device='cuda')

    # fp32 reference: un-normalized softmax(QK^T * scale) @ V, plus row-0 LSE.
    qf = q[:, :, 0].float()
    kf = k[:, :, 0].float()
    scale = 1.0 / math.sqrt(hd)
    scores = qf @ kf.T * scale
    attn_max = scores.max(dim=-1, keepdim=True)[0]
    attn_exp = torch.exp(scores - attn_max)
    attn_sum = attn_exp.sum(dim=-1, keepdim=True)
    ref_unnorm = attn_exp @ v.float()
    ref_lse = (torch.log(attn_sum.squeeze(-1)) + attn_max.squeeze(-1))[0].item()

    lse_tensor = torch.zeros(m, 1, 1, dtype=torch.float32, device='cuda')
    kernel = FmhaKernel(head_dim=hd, s_k=n, use_smem_p=False)
    pv_n_tile = kernel.pv_n_tile
    n_pv_tiles = kernel.n_pv_tiles
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    print(f'hd={hd}, pv_n_tile={pv_n_tile}, n_pv_tiles={n_pv_tiles}, n_k_sub_tiles={kernel.n_k_sub_tiles}, k_tile={kernel.k_tile}', flush=True)
    print('Compiling first PV tile...', flush=True)

    # Q, K and the LSE buffer are the same for every PV tile, so wrap them once.
    mQ = as_cute(q)
    mK = as_cute(k)
    mLSE = as_cute(lse_tensor)

    # Only compile the first PV tile to isolate compilation issues
    v_kernel = v[:, 0:pv_n_tile].contiguous().unsqueeze(-1)
    c_tile = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
    mV = as_cute(v_kernel)
    mC = as_cute(c_tile)

    t0 = time.perf_counter()
    compiled = cute.compile(kernel, mQ, mK, mV, mC, stream, mLSE)
    t1 = time.perf_counter()
    print(f'Compilation took {t1-t0:.1f}s', flush=True)

    # Run all PV tiles
    lse_val = None
    c = torch.zeros(m, hd, 1, dtype=torch.bfloat16, device='cuda')
    for nt in range(n_pv_tiles):
        v_start = nt * pv_n_tile
        v_end = v_start + pv_n_tile
        v_kernel = v[:, v_start:v_end].contiguous().unsqueeze(-1)
        c_tile = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
        lse_tensor.zero_()

        compiled(mQ, mK, as_cute(v_kernel), as_cute(c_tile), stream, mLSE)
        torch.cuda.synchronize()
        print(f' PV tile {nt}: done', flush=True)

        c[:, v_start:v_end, :] = c_tile
        if nt == 0:
            lse_val = lse_tensor[0, 0, 0].item()

    out_unnorm = c[:, :, 0].float()
    cos_unnorm = torch.nn.functional.cosine_similarity(
        out_unnorm.flatten().unsqueeze(0), ref_unnorm.flatten().unsqueeze(0)
    ).item()
    lse_err = abs(lse_val - ref_lse) if lse_val is not None else float('inf')

    status = "PASS" if cos_unnorm >= 0.99 else "FAIL"
    print(f'hd={hd}: cos_unnorm {cos_unnorm:.6f} lse_err {lse_err:.6f} {status}')

    # Without this the test would print FAIL yet still pass under pytest.
    # A NaN similarity also fails here, since NaN >= 0.99 is False.
    assert cos_unnorm >= 0.99, (
        f'hd={hd}: cosine similarity {cos_unnorm:.6f} is below 0.99'
    )
|
||||
|
||||
# Allow running this file directly as a script, outside of pytest.
if __name__ == '__main__':
    test()
|
||||
164
tests/unit/test_d1_qk512.py
Normal file
164
tests/unit/test_d1_qk512.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""Minimal hd=512 test: ONLY QK GEMM, no softmax, no PV.
|
||||
Goal: isolate whether the compilation hang is from QK or softmax/PV."""
|
||||
import math
import time

import cuda.bindings.driver as cuda
import cutlass
import cutlass.cute as cute
import cutlass.pipeline as pipeline
import cutlass.torch as ct
import cutlass.utils as utils
import torch
from cutlass import BFloat16, Float32, const_expr
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass.utils import LayoutEnum
|
||||
|
||||
class QkOnly512:
    """QK-only GEMM kernel for head_dim=512: no softmax, no PV, no output.

    The head dimension is consumed as ``n_k_sub_tiles`` (2) sub-tiles of
    ``k_tile`` (256) elements each, accumulated into one 128x128 fp32
    accumulator.  The ``s_out`` argument is accepted but never written;
    the kernel exists only to exercise compilation and launch of the QK
    stage in isolation.

    Warp roles within the 192-thread CTA: warps 0-3 allocate and free
    tensor memory, warp 4 issues the MMAs, warp 5 issues the TMA loads.
    """

    def __init__(self):
        self.head_dim = 512
        self.k_tile = 256
        self.n_k_sub_tiles = 2
        self.kv_stage = 1
        self.q_stage = 1
        self.q_dtype = BFloat16
        self.qk_acc_dtype = Float32
        self.cta_group = tcgen05.CtaGroup.ONE
        self.cluster_shape_mn = (1, 1)
        self.qk_mma_tiler = (128, 128, self.k_tile)
        # One tensor-memory column per accumulator column: the fp32
        # accumulator is 128 wide, so 64 columns would not cover it.
        self.tmem_alloc_cols = self.qk_mma_tiler[1]
        self.threads_per_cta = 192
        self.mma_warp_id = 4
        self.tma_warp_id = 5
        self.epilogue_warp_id = (0, 1, 2, 3)

    def _setup(self, qk_mma):
        """Derive the staged smem layouts and per-stage TMA byte counts."""
        self.q_smem_s = utils.sm100.make_smem_layout_a(
            qk_mma, self.qk_mma_tiler, self.q_dtype, self.q_stage
        )
        self.k_smem_s = utils.sm100.make_smem_layout_b(
            qk_mma, self.qk_mma_tiler, self.q_dtype, self.kv_stage
        )
        cta = cute.size(qk_mma.thr_id.shape)
        q_s = cute.slice_(self.q_smem_s, (None, None, None, 0))
        k_s = cute.slice_(self.k_smem_s, (None, None, None, 0))
        self.q_tx_bytes = cute.size_in_bytes(self.q_dtype, q_s) * cta
        self.kv_tx_bytes = cute.size_in_bytes(self.q_dtype, k_s) * cta

    @cute.jit
    def __call__(self, q, k, s_out, stream):
        """Build the tiled MMA and TMA atoms, then launch one CTA on ``stream``.

        Args:
            q: CuTe tensor for Q.
            k: CuTe tensor for K.
            s_out: placeholder output tensor; forwarded to the kernel, unused.
            stream: CUDA stream to launch on.
        """
        self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
        self.b_major = LayoutEnum.from_tensor(k).mma_major_mode()
        qk_mma = utils.sm100.make_trivial_tiled_mma(
            self.q_dtype, self.q_dtype, self.a_major, self.b_major,
            self.qk_acc_dtype, self.cta_group, (128, 128),
            tcgen05.OperandSource.SMEM,
        )
        self._setup(qk_mma)
        q_s = cute.slice_(self.q_smem_s, (None, None, None, 0))
        k_s = cute.slice_(self.k_smem_s, (None, None, None, 0))
        tma_q, mQ = cute.nvgpu.make_tiled_tma_atom_A(
            utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
            q, q_s, self.qk_mma_tiler, qk_mma, self.cluster_shape_mn,
        )
        tma_k, mK = cute.nvgpu.make_tiled_tma_atom_B(
            utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
            k, k_s, self.qk_mma_tiler, qk_mma, self.cluster_shape_mn,
        )
        self._kernel(qk_mma, tma_q, mQ, tma_k, mK, s_out).launch(
            grid=(1, 1, 1), block=[self.threads_per_cta, 1, 1], stream=stream
        )

    @cute.kernel
    def _kernel(self, qk_mma, tma_q, mQ, tma_k, mK, s_out):
        """Device kernel: TMA-load two Q/K sub-tiles and accumulate Q @ K^T."""
        warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
        if warp_idx == self.tma_warp_id:
            cpasync.prefetch_descriptor(tma_q)
            cpasync.prefetch_descriptor(tma_k)

        @cute.struct
        class SS:
            q_bar: cute.struct.MemRange[cutlass.Int64, self.q_stage*2]
            kv_bar: cute.struct.MemRange[cutlass.Int64, self.kv_stage*2]
            holding: cutlass.Int32

        smem = utils.SmemAllocator()
        st = smem.allocate(SS)

        # The same single-CTA layout is handed to both pipelines and to the
        # pipeline init arrive/wait calls.
        cta_layout_vmnk = cute.tiled_divide(
            cute.make_layout((1, 1, 1)), (qk_mma.thr_id.shape,)
        )

        qp, qc = pipeline.PipelineTmaUmma.create(
            barrier_storage=st.q_bar.data_ptr(),
            num_stages=self.q_stage,
            producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
            consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
            tx_count=self.q_tx_bytes,
            cta_layout_vmnk=cta_layout_vmnk,
            defer_sync=True,
        ).make_participants()
        kvp, kvc = pipeline.PipelineTmaUmma.create(
            barrier_storage=st.kv_bar.data_ptr(),
            num_stages=self.kv_stage,
            producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
            consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
            tx_count=self.kv_tx_bytes,
            cta_layout_vmnk=cta_layout_vmnk,
            defer_sync=True,
        ).make_participants()

        tmem_bar = pipeline.NamedBarrier(
            barrier_id=2,
            num_threads=32*len((self.mma_warp_id, *self.epilogue_warp_id)),
        )
        tmem = utils.TmemAllocator(
            st.holding.ptr,
            barrier_for_retrieve=tmem_bar,
            allocator_warp_id=self.epilogue_warp_id[0],
            is_two_cta=cute.size(qk_mma.thr_id.shape) == 2,
        )
        pipeline.pipeline_init_arrive(cluster_shape_mn=cta_layout_vmnk, is_relaxed=True)

        sQ = smem.allocate_tensor(
            element_type=self.q_dtype, layout=self.q_smem_s.outer,
            byte_alignment=128, swizzle=self.q_smem_s.inner,
        )
        sK = smem.allocate_tensor(
            element_type=self.q_dtype, layout=self.k_smem_s.outer,
            byte_alignment=128, swizzle=self.k_smem_s.inner,
        )

        gQ = cute.local_tile(mQ, cute.slice_(self.qk_mma_tiler, (None, 0, None)), (None, None, None))
        gK = cute.local_tile(mK, cute.slice_(self.qk_mma_tiler, (0, None, None)), (None, None, None))

        qk_thr = qk_mma.get_slice(0)
        tCgQ = qk_thr.partition_A(gQ)
        tCgK = qk_thr.partition_B(gK)
        a_lay = cute.make_layout((1,))
        b_lay = cute.make_layout((1,))
        tAsQ, tAgQ = cpasync.tma_partition(
            tma_q, 0, a_lay, cute.group_modes(sQ, 0, 3), cute.group_modes(tCgQ, 0, 3)
        )
        tBsK, tBgK = cpasync.tma_partition(
            tma_k, 0, b_lay, cute.group_modes(sK, 0, 3), cute.group_modes(tCgK, 0, 3)
        )
        tAgQ = tAgQ[(None, 0, None, 0)]
        tBgK = tBgK[(None, 0, None, 0)]

        tCrQ = qk_mma.make_fragment_A(sQ)
        tCrK = qk_mma.make_fragment_B(sK)
        qk_as = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
        tStS = qk_thr.make_fragment_C(qk_as)
        tStS0 = cute.make_tensor(tStS.iterator, tStS.layout)

        pipeline.pipeline_init_wait(cluster_shape_mn=cta_layout_vmnk)

        # ===== TMA LOAD warp =====
        if warp_idx == self.tma_warp_id:
            qp.reset()
            kvp.reset()
            # k_sub=0
            qh0 = qp.acquire_and_advance()
            cute.copy(tma_q, tAgQ[(None, cutlass.Int32(0))], tAsQ[(None, qh0.index)], tma_bar_ptr=qh0.barrier)
            kvh0 = kvp.acquire_and_advance()
            cute.copy(tma_k, tBgK[(None, cutlass.Int32(0))], tBsK[(None, kvh0.index)], tma_bar_ptr=kvh0.barrier)
            # k_sub=1
            qh1 = qp.acquire_and_advance()
            cute.copy(tma_q, tAgQ[(None, cutlass.Int32(1))], tAsQ[(None, qh1.index)], tma_bar_ptr=qh1.barrier)
            kvh1 = kvp.acquire_and_advance()
            cute.copy(tma_k, tBgK[(None, cutlass.Int32(1))], tBsK[(None, kvh1.index)], tma_bar_ptr=kvh1.barrier)
            qp.tail()
            kvp.tail()

        # ===== MMA warp =====
        if warp_idx == self.mma_warp_id:
            tmem.wait_for_alloc()
            # k_sub=0
            qh0 = qc.wait_and_advance()
            kvh0 = kvc.wait_and_advance()
            qk_mma.set(tcgen05.Field.ACCUMULATE, False)
            for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True):
                cute.gemm(qk_mma, tStS0, tCrQ[(None, None, kb, qh0.index)], tCrK[(None, None, kb, kvh0.index)], tStS0)
                qk_mma.set(tcgen05.Field.ACCUMULATE, True)
            kvh0.release()
            # Release Q only after the MMAs that read it have been issued:
            # with q_stage == 1 the TMA warp reuses this same buffer for the
            # next sub-tile once the stage is released.
            qh0.release()
            # k_sub=1
            qh1 = qc.wait_and_advance()
            kvh1 = kvc.wait_and_advance()
            for kb in cutlass.range(cute.size(tCrQ, mode=[2]), unroll_full=True):
                cute.gemm(qk_mma, tStS0, tCrQ[(None, None, kb, qh1.index)], tCrK[(None, None, kb, kvh1.index)], tStS0)
                qk_mma.set(tcgen05.Field.ACCUMULATE, True)
            kvh1.release()
            qh1.release()
            cute.arch.fence_view_async_tmem_store()

        # Epilogue warps just allocate/free TMEM
        # NOTE(review): nothing here waits for the MMA warp to finish before
        # the free below; presumably harmless for a compile/launch smoke test,
        # but confirm before reusing this structure in a real kernel.
        if warp_idx < self.mma_warp_id:
            tmem.allocate(self.tmem_alloc_cols)
            tmem.wait_for_alloc()
            tmem.relinquish_alloc_permit()
            tmem.free(tmem.retrieve_ptr(self.qk_acc_dtype))
|
||||
|
||||
|
||||
def test():
    """Compile and launch the QK-only hd=512 kernel once.

    Prints how long compilation took, then a success line once the launch
    has been synchronized.  No numerical result is checked.
    """
    def wrap(tensor):
        # Expose a torch tensor to CuTe with its leading dimension left dynamic.
        lead = ct.get_leading_dim(tensor)
        return ct.from_dlpack(tensor).mark_layout_dynamic(leading_dim=lead)

    torch.manual_seed(42)
    hd, n, m = 512, 128, 128

    bf16_cuda = dict(dtype=torch.bfloat16, device='cuda')
    q = torch.randn(m, hd, 1, **bf16_cuda)
    k = torch.randn(n, hd, 1, **bf16_cuda)
    s_out = torch.zeros(1, dtype=torch.float32, device='cuda')  # dummy

    launch_args = (
        wrap(q),
        wrap(k),
        wrap(s_out),
        cuda.CUstream(torch.cuda.current_stream().cuda_stream),
    )

    kernel = QkOnly512()
    print('Compiling QK-only hd=512...', flush=True)
    started = time.time()
    compiled = cute.compile(kernel, *launch_args)
    finished = time.time()
    print(f'Compilation took {finished-started:.1f}s', flush=True)

    compiled(*launch_args)
    torch.cuda.synchronize()
    print('QK-only hd=512: SUCCESS')
|
||||
|
||||
# Allow running this file directly as a script, outside of pytest.
if __name__ == '__main__':
    test()
|
||||
Reference in New Issue
Block a user