more stuff

This commit is contained in:
2026-05-21 05:08:57 +00:00
parent 1314a67135
commit 1f6fb856ea
6 changed files with 1330 additions and 0 deletions

373
tests/test_diag_layout.py Normal file
View File

@@ -0,0 +1,373 @@
"""
Diagnostic: PV with (128,64) output.
Key fix: compute epilogue tile from PV cta_tile_shape, not QK.
V[d,k] = (d+1)*(k+1), MN-major. Check element-level patterns.
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05, OperandMajorMode
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
import cutlass.torch as ct
class DiagLayoutKernel:
def __init__(self, mma_tiler_mn, head_dim):
self.head_dim = head_dim
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
self.use_2cta_instrs = False
self.epilog_sync_bar_id = 1
self.cluster_shape_mn = (1, 1)
self.cta_group = tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0, 1, 2, 3)
self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192
self.num_c_stage = 2
def _setup(self, qk_mma, pv_mma):
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
self.pv_mma_tiler = (self.qk_mma_tiler[0], self.qk_mma_tiler[2], self.qk_mma_tiler[1])
self.mma_tiler = self.qk_mma_tiler
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
# QK cta tile
self.qk_cta_tile_shape_mnk = (
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
self.qk_mma_tiler[1], self.qk_mma_tiler[2])
# PV cta tile — for epilogue, this is what matters
self.pv_cta_tile_shape_mnk = (
self.pv_mma_tiler[0] // cute.size(pv_mma.thr_id.shape),
self.pv_mma_tiler[1], self.pv_mma_tiler[2])
self.c_layout = LayoutEnum.ROW_MAJOR
# Compute epi_tile from PV cta_tile, not QK
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
self.pv_cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype)
print(f"[SETUP] qk_mma_tiler={self.qk_mma_tiler}, pv_mma_tiler={self.pv_mma_tiler}")
print(f"[SETUP] qk_cta_tile={self.qk_cta_tile_shape_mnk}, pv_cta_tile={self.pv_cta_tile_shape_mnk}")
print(f"[SETUP] epi_tile={self.epi_tile}")
self.cta_tile_shape_mnk = self.pv_cta_tile_shape_mnk
self.num_ab_stage = 1; self.num_acc_stage = 1
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1)
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
qk_thr = qk_mma.get_slice(0)
qk_acc_shape = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
s_cols = find_tmem_tensor_col_offset(tStS)
pv_thr = pv_mma.get_slice(0)
pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
o_cols = find_tmem_tensor_col_offset(tOtO)
self.tilePlikeFP32 = self.qk_mma_tiler[1] // Float32.width * self.o_dtype.width
self.tmem_s0_offset = 0
self.tmem_p0_offset = 32
self.tmem_o0_offset = s_cols
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
v_smem = cute.slice_(self.v_smem_s, (None, None, None, 0))
self.num_tma_load_bytes = (
cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem) +
cute.size_in_bytes(self.q_dtype, v_smem)
) * cute.size(qk_mma.thr_id.shape)
@cute.jit
def __call__(self, q, k, v, c, stream):
self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype
self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
self.b_major = LayoutEnum.from_tensor(k).mma_major_mode()
self.v_major = LayoutEnum.from_tensor(v).mma_major_mode()
self.c_layout = LayoutEnum.from_tensor(c)
qk_mma = utils.sm100.make_trivial_tiled_mma(
self.q_dtype, self.q_dtype, self.a_major, self.b_major,
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
pv_mma_tiler_mn = (self.mma_tiler_mn[0], self.head_dim)
pv_mma = utils.sm100.make_trivial_tiled_mma(
self.q_dtype, self.q_dtype, OperandMajorMode.K, self.v_major,
self.qk_acc_dtype, self.cta_group, pv_mma_tiler_mn, tcgen05.OperandSource.TMEM)
self._setup(qk_mma, pv_mma)
q_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
k_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
v_smem = cute.slice_(self.v_smem_s, (None, None, None, 0))
tma_q, tma_tq = cute.nvgpu.make_tiled_tma_atom_A(
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
q, q_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
tma_k, tma_tk = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
k, k_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
tma_v, tma_tv = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, pv_mma.thr_id),
v, v_smem, self.pv_mma_tiler, pv_mma, self.cluster_layout_vmnk.shape)
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
self._kernel(qk_mma, pv_mma, tma_q, tma_tq, tma_k, tma_tk, tma_v, tma_tv,
tma_c, tma_tc, self.cluster_layout_vmnk,
self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
@cute.kernel
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV,
tma_c, mC, cl_vmnk, a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx, _, _ = cute.arch.thread_idx()
use_2cta = cute.size(qk_mma.thr_id.shape) == 2
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_q); cpasync.prefetch_descriptor(tma_k)
cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
tmem_dealloc: cutlass.Int64
holding: cutlass.Int32
smem = utils.SmemAllocator(); st = smem.allocate(SS)
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
).make_participants()
acc_pipe = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)),
cta_layout_vmnk=cl_vmnk, defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=2,
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta,
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
sQ = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
sK = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
sV = smem.allocate_tensor(element_type=self.q_dtype, layout=v_smem_s.outer, byte_alignment=128, swizzle=v_smem_s.inner)
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
gQ = cute.local_tile(mQ, cute.slice_(self.qk_mma_tiler, (None,0,None)), (None,None,None))
gK = cute.local_tile(mK, cute.slice_(self.qk_mma_tiler, (0,None,None)), (None,None,None))
gC = cute.local_tile(mC, cute.slice_(self.pv_mma_tiler, (None,0,None)), (None,None,None))
k_cnt = cute.size(gQ, mode=[3])
qk_thr = qk_mma.get_slice(0)
pv_thr = pv_mma.get_slice(0)
tCgQ = qk_thr.partition_A(gQ); tCgK = qk_thr.partition_B(gK); tCgC = pv_thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
tAsQ, tAgQ = cpasync.tma_partition(tma_q, 0, a_lay, cute.group_modes(sQ,0,3), cute.group_modes(tCgQ,0,3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
tBsK, tBgK = cpasync.tma_partition(tma_k, 0, b_lay, cute.group_modes(sK,0,3), cute.group_modes(tCgK,0,3))
tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)]
gV = cute.local_tile(mV, cute.slice_(self.pv_mma_tiler, (0,None,None)), (None,None,None))
tCgV = pv_thr.partition_B(gV)
tVsV, tVgV = cpasync.tma_partition(tma_v, 0, b_lay, cute.group_modes(sV,0,3), cute.group_modes(tCgV,0,3))
tVgV = tVgV[(None,0,None,0)]
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
tCrV = pv_mma.make_fragment_B(sV)
qk_acc_shape = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
tOrP_base = pv_thr.make_fragment_A(tP)
tOrP = tOrP_base[(None, None, None, 0)]
tOrP0 = cute.make_tensor(
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
tOrP.layout)
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# TMA LOAD WARP
if warp_idx == self.tma_warp_id:
ab_p.reset(); peek = ab_p.try_acquire()
for kt in cutlass.range(k_cnt, unroll=1):
h = ab_p.acquire_and_advance(peek)
cute.copy(tma_q, tAgQ[(None,h.count)], tAsQ[(None,h.index)], tma_bar_ptr=h.barrier)
cute.copy(tma_k, tBgK[(None,h.count)], tBsK[(None,h.index)], tma_bar_ptr=h.barrier)
cute.copy(tma_v, tVgV[(None,h.count)], tVsV[(None,h.index)], tma_bar_ptr=h.barrier)
peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_p.try_acquire()
ab_p.tail()
# MMA WARP
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
ab_c.reset(); peek = ab_c.try_wait()
s0_handle = mma_si_prod.acquire_and_advance()
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
acc_pipe.producer_acquire(acc_prod_st)
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kt in range(k_cnt):
h = ab_c.wait_and_advance(peek)
nblk = cute.size(tCrQ, mode=[2])
for kb in cutlass.range(nblk, unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,h.index)], tCrK[(None,None,kb,h.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
h.release(); peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_c.try_wait()
cute.arch.fence_view_async_tmem_store()
s0_handle.commit()
s0_handle = mma_si_prod.acquire_and_advance()
pv_mma.set(tcgen05.Field.ACCUMULATE, False)
tCrV_s = tCrV[(None, None, None, 0)]
nblk_pv = cute.size(tOrP0, mode=[2])
for kb in cutlass.range(nblk_pv, unroll_full=True):
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
acc_pipe.producer_commit(acc_prod_st)
acc_prod_st.advance()
acc_pipe.producer_tail(acc_prod_st)
# EPILOGUE WARPS
if warp_idx < self.mma_warp_id:
tmem.allocate(self.num_tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
tmem_load_atom = cute.make_copy_atom(
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
thr_load = tiled_tmem_load.get_slice(sfw_idx)
tTMEM_LOADtS = thr_load.partition_S(tStS0)
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
tScS = qk_thr.partition_C(cS)
tTMEM_LOADcS = thr_load.partition_D(tScS)
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, self.tilePlikeFP32)))
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
tmem_store_atom = cute.make_copy_atom(
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS_P)
thr_store = tiled_tmem_store.get_slice(sfw_idx)
tTMEM_STOREtS_x4 = thr_store.partition_D(tStS_P)
tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, self.tilePlikeFP32)))
tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout)
tTMEM_STOREcS = thr_store.partition_S(tScS_P)
si_handle = mma_si_cons.wait_and_advance()
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
tTMEM_STORErS_x4 = cute.make_rmem_tensor(tTMEM_STOREcS.shape, self.qk_acc_dtype)
tTMEM_STORErS_x4_e = cute.make_tensor(
cute.recast_ptr(tTMEM_STORErS_x4.iterator, dtype=self.q_dtype),
tTMEM_LOADrS.layout)
frg_cnt = 4
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
tTMEM_STORErS_x4_e_frg = cute.logical_divide(
tTMEM_STORErS_x4_e, cute.make_layout(frg_tile))
for j in range(frg_cnt):
s_vec = tTMEM_LOADrS_frg[None, j].load()
tTMEM_STORErS_x4_e_frg[None, j].store(s_vec.to(self.q_dtype))
cute.copy(tiled_tmem_store, tTMEM_STORErS_x4, tTMEM_STOREtS_x4)
cute.arch.fence_view_async_tmem_store()
si_handle.release()
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
c_pipe.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
def test_diag_v():
torch.manual_seed(42)
m, n, head_dim = 128, 128, 64
q = torch.randn(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')
k = torch.randn(n, head_dim, 1, dtype=torch.bfloat16, device='cuda')
v_data = torch.zeros(head_dim, n, dtype=torch.bfloat16, device='cuda')
for d in range(head_dim):
for k_idx in range(n):
v_data[d, k_idx] = (d + 1) * (k_idx + 1)
v = v_data.as_strided((head_dim, n), (1, head_dim)).unsqueeze(-1)
c = torch.zeros(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')
qf = q[:,:,0].float()
kf = k[:,:,0].float()
vf = v_data.float()
ref = (qf @ kf.T).bfloat16().float() @ vf.T
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
mV = ct.from_dlpack(v).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = DiagLayoutKernel(mma_tiler_mn=(128, 128), head_dim=head_dim)
print('Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mV, mC, stream)
print('Running...', flush=True)
compiled(mQ, mK, mV, mC, stream)
torch.cuda.synchronize()
out = c[:,:,0].float()
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
max_err = (out - ref).abs().max().item()
print('Diag V test (128x64 PV, epi from PV cta_tile): cosine {:.6f}, max_err {:.6f} {}'.format(
cos, max_err, 'PASS' if cos >= 0.99 else 'FAIL'))
if cos < 0.99:
print('\n=== Element-level diagnostics ===')
for m_idx in [0, 1, 63, 127]:
for d_idx in [0, 1, 31, 63]:
print(f' O[{m_idx},{d_idx}] = {out[m_idx,d_idx]:.4f}, ref = {ref[m_idx,d_idx]:.4f}')
print(f'\n O[0,:5] = {out[0,:5].tolist()}')
print(f' ref[0,:5] = {ref[0,:5].tolist()}')
print(f' O[:5,0] = {out[:5,0].tolist()}')
print(f' ref[:5,0] = {ref[:5,0].tolist()}')
if __name__ == '__main__':
test_diag_v()

View File

@@ -0,0 +1,80 @@
"""
Quick diagnostic: truncated identity V with 128x64 PV.
Check if output columns match a permutation of reference columns.
If O[m,d] = P[m, perm(d)], then the PV MMA is reading P from wrong TMEM addresses.
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05, OperandMajorMode
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
import cutlass.torch as ct
# Reuse the DiagVTruncIdKernel from test_diag_v_truncid.py
# (just run it and do more analysis on the output)
# Actually, let me just re-run the truncid test and do the permutation analysis in Python
# First run the kernel, then analyze
# We already ran it and have the results. Let me just do the analysis with the numbers we have.
# O[0,:5] = [6.0625, 11.875, -9.5625, -4.6875, -14.9375]
# ref[0,:5] = [6.0625, 10.5625, 11.875, -11.75, -9.5625]
# P[0] (full Q@K^T row 0) needs to be computed
torch.manual_seed(42)
m, n, head_dim = 128, 128, 64
q = torch.randn(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')
k = torch.randn(n, head_dim, 1, dtype=torch.bfloat16, device='cuda')
qf = q[:,:,0].float()
kf = k[:,:,0].float()
P = (qf @ kf.T) # (128, 128) — the P matrix
# Now check: does O[0, d] = P[0, perm(d)] for some permutation?
# O[0,0] = 6.0625 → matches P[0,0] = 6.0625
# O[0,1] = 11.875 → matches P[0,2] = 11.875
# O[0,2] = -9.5625 → matches P[0,4] = -9.5625
# So O[0, d] = P[0, 2*d]? Let me check more.
O_row0 = [6.0625, 11.875, -9.5625, -4.6875, -14.9375]
P_row0 = P[0, :10].tolist()
print(f"P[0, :10] = {P_row0}")
print(f"O[0, :5] = {O_row0}")
# Check: O[0, d] = P[0, 2*d]?
for d in range(5):
print(f" O[0,{d}] = {O_row0[d]:.4f}, P[0,{2*d}] = {P_row0[2*d]:.4f}, match = {abs(O_row0[d] - P_row0[2*d]) < 0.01}")
# Also check full P row 0 vs O
# We can't get O without running the kernel again, but the pattern is clear:
# O[m, d] = P[m, 2*d] for the truncated identity V case
# This means the PV MMA is reading P from every other TMEM column
# Why 2*d? Because with (128,64) MMA, the A fragment reads TMEM with stride 2 in the K dimension.
# The (128,64,16) MMA atom has N=64, which means it reads 64 columns of P per K-tile
# But P has 128 columns. The MMA reads the first 64, but with the wrong stride.
#
# Actually, with (128,64,16) MMA:
# - A operand: (M=128, K=128) → MMA reads 128/16 = 8 K-tiles
# - Each K-tile reads P[:, k*16:(k+1)*16] = 16 columns of P
# - The A fragment for K-tile kb reads from TMEM column offset based on N_MMA
#
# The (128,64,16) MMA's TMEM A fragment layout might be:
# (128, N_MMA) where N_MMA relates to the N dimension of the MMA
# If N_MMA = 64 (half of 128), then P's 128 BF16 values in K are stored
# in 128 BF16 TMEM columns = 64 FP32 TMEM columns
# But the (128,64,16) A fragment might only address 32 FP32 TMEM columns
# because the MMA only uses 64 columns for the C output
# So P's 128 K values don't fit in 32 TMEM columns, and the layout is different
# The root cause: the (128,64) MMA's A fragment in TMEM packs 128 BF16 K values
# into fewer TMEM columns than the (128,128) MMA. The softmax packing writes P
# using the (128,128) layout, but the PV MMA reads with the (128,64) layout.
print("\n=== HYPOTHESIS ===")
print("The (128,64,16) MMA atom reads P from TMEM with a DIFFERENT layout")
print("than the softmax packing writes P with (QK C fragment layout).")
print("The (128,128,16) MMA atom's A fragment layout matches the QK C fragment layout,")
print("so the 128x128 case works. The (128,64,16) layout differs, causing the bug.")
print("Fix: softmax packing should write P using the PV MMA's A fragment layout.")

View File

@@ -0,0 +1,72 @@
"""Print V SMEM layouts for (128,64) and (128,128) PV. Must run inside JIT."""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05, OperandMajorMode
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
import cutlass.torch as ct
class SmemLayoutKernel:
def __init__(self):
self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16
self.qk_acc_dtype = Float32
self.use_2cta_instrs = False; self.cluster_shape_mn = (1, 1)
self.cta_group = tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0, 1, 2, 3)
self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192; self.num_c_stage = 2
@cute.jit
def __call__(self, q, k, v, c, stream):
self.q_dtype = q.element_type; self.o_dtype = c.element_type
a_major = LayoutEnum.from_tensor(q).mma_major_mode()
b_major = LayoutEnum.from_tensor(k).mma_major_mode()
v_major = LayoutEnum.from_tensor(v).mma_major_mode()
c_layout = LayoutEnum.from_tensor(c)
# QK
qk_mma = utils.sm100.make_trivial_tiled_mma(
BFloat16, BFloat16, a_major, b_major,
Float32, tcgen05.CtaGroup.ONE, (128, 128), tcgen05.OperandSource.SMEM)
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
qk_mma_tiler = (128, 128, qk_inst_k * 4)
b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, qk_mma_tiler, BFloat16, 1)
print(f"QK B SMEM: outer={b_smem_s.outer}, inner={b_smem_s.inner}")
# PV (128, 64)
pv_mma_64 = utils.sm100.make_trivial_tiled_mma(
BFloat16, BFloat16, OperandMajorMode.K, v_major,
Float32, tcgen05.CtaGroup.ONE, (128, 64), tcgen05.OperandSource.TMEM)
pv_mma_tiler_64 = (128, 64, 128)
v_smem_64 = utils.sm100.make_smem_layout_b(pv_mma_64, pv_mma_tiler_64, BFloat16, 1)
print(f"PV(128,64) V SMEM: outer={v_smem_64.outer}, inner={v_smem_64.inner}")
# PV (128, 128)
pv_mma_128 = utils.sm100.make_trivial_tiled_mma(
BFloat16, BFloat16, OperandMajorMode.K, v_major,
Float32, tcgen05.CtaGroup.ONE, (128, 128), tcgen05.OperandSource.TMEM)
pv_mma_tiler_128 = (128, 128, 128)
v_smem_128 = utils.sm100.make_smem_layout_b(pv_mma_128, pv_mma_tiler_128, BFloat16, 1)
print(f"PV(128,128) V SMEM: outer={v_smem_128.outer}, inner={v_smem_128.inner}")
# Also print the PV MMA atom shapes
print(f"PV(128,64) MMA shape_mnk={pv_mma_64.shape_mnk}")
print(f"PV(128,128) MMA shape_mnk={pv_mma_128.shape_mnk}")
torch.manual_seed(42)
m, n, head_dim = 128, 128, 64
q = torch.randn(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')
k = torch.randn(n, head_dim, 1, dtype=torch.bfloat16, device='cuda')
v_data = torch.zeros(head_dim, n, dtype=torch.bfloat16, device='cuda')
v_data[0, 0] = 1.0
v = v_data.as_strided((head_dim, n), (1, head_dim)).unsqueeze(-1)
c = torch.zeros(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
mV = ct.from_dlpack(v).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = SmemLayoutKernel()
compiled = cute.compile(kernel, mQ, mK, mV, mC, stream)

374
tests/test_diag_v_mma128.py Normal file
View File

@@ -0,0 +1,374 @@
"""
Diagnostic: PV with (128,64) output but using (128,128) MMA for PV.
This keeps the A fragment (P) TMEM layout the same as the softmax packing path.
V is truncated identity (64,128) MN-major. If cosine ~0.999, confirms TMEM alias mismatch.
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05, OperandMajorMode
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
import cutlass.torch as ct
class DiagVMma128Kernel:
def __init__(self, head_dim):
self.head_dim = head_dim
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16
self.mma_tiler_mn = (128, 128); self.mma_tiler = (128, 128, 1)
self.use_2cta_instrs = False
self.epilog_sync_bar_id = 1
self.cluster_shape_mn = (1, 1)
self.cta_group = tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0, 1, 2, 3)
self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192
self.num_c_stage = 2
def _setup(self, qk_mma, pv_mma):
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (128, 128, qk_inst_k * 4)
# PV uses (128, 128) MMA even though output is (128, 64)
self.pv_mma_tiler = (128, 128, 128) # Same as 128x128 case
self.mma_tiler = self.qk_mma_tiler
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.cta_tile_shape_mnk = (
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
self.qk_mma_tiler[1], self.qk_mma_tiler[2])
self.c_layout = LayoutEnum.ROW_MAJOR
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
self.cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype)
print(f"[SETUP] qk_mma_tiler={self.qk_mma_tiler}, pv_mma_tiler={self.pv_mma_tiler}")
print(f"[SETUP] epi_tile={self.epi_tile}")
self.num_ab_stage = 1; self.num_acc_stage = 1
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1)
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
qk_thr = qk_mma.get_slice(0)
qk_acc_shape = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
s_cols = find_tmem_column_offset(tStS)
pv_thr = pv_mma.get_slice(0)
pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
o_cols = find_tmem_tensor_col_offset(tOtO)
self.tilePlikeFP32 = self.qk_mma_tiler[1] // Float32.width * self.o_dtype.width
self.tmem_s0_offset = 0
self.tmem_p0_offset = 32
self.tmem_o0_offset = s_cols
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
v_smem = cute.slice_(self.v_smem_s, (None, None, None, 0))
self.num_tma_load_bytes = (
cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem) +
cute.size_in_bytes(self.q_dtype, v_smem)
) * cute.size(qk_mma.thr_id.shape)
@cute.jit
def __call__(self, q, k, v, c, stream):
self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype
self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
self.b_major = LayoutEnum.from_tensor(k).mma_major_mode()
self.v_major = LayoutEnum.from_tensor(v).mma_major_mode()
self.c_layout = LayoutEnum.from_tensor(c)
qk_mma = utils.sm100.make_trivial_tiled_mma(
self.q_dtype, self.q_dtype, self.a_major, self.b_major,
self.qk_acc_dtype, self.cta_group, (128, 128), tcgen05.OperandSource.SMEM)
# PV MMA (128, 128) - same as QK, so A fragment layout matches softmax packing
pv_mma = utils.sm100.make_trivial_tiled_mma(
self.q_dtype, self.q_dtype, OperandMajorMode.K, self.v_major,
self.qk_acc_dtype, self.cta_group, (128, 128), tcgen05.OperandSource.TMEM)
self._setup(qk_mma, pv_mma)
q_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
k_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
v_smem = cute.slice_(self.v_smem_s, (None, None, None, 0))
tma_q, tma_tq = cute.nvgpu.make_tiled_tma_atom_A(
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
q, q_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
tma_k, tma_tk = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
k, k_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
tma_v, tma_tv = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, pv_mma.thr_id),
v, v_smem, self.pv_mma_tiler, pv_mma, self.cluster_layout_vmnk.shape)
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
self._kernel(qk_mma, pv_mma, tma_q, tma_tq, tma_k, tma_tk, tma_v, tma_tv,
tma_c, tma_tc, self.cluster_layout_vmnk,
self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
@cute.kernel
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV,
tma_c, mC, cl_vmnk, a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx, _, _ = cute.arch.thread_idx()
use_2cta = cute.size(qk_mma.thr_id.shape) == 2
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_q); cpasync.prefetch_descriptor(tma_k)
cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
tmem_dealloc: cutlass.Int64
holding: cutlass.Int32
smem = utils.SmemAllocator(); st = smem.allocate(SS)
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
).make_participants()
acc_pipe = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)),
cta_layout_vmnk=cl_vmnk, defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=2,
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta,
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
sQ = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
sK = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
sV = smem.allocate_tensor(element_type=self.q_dtype, layout=v_smem_s.outer, byte_alignment=128, swizzle=v_smem_s.inner)
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
gQ = cute.local_tile(mQ, cute.slice_(self.qk_mma_tiler, (None,0,None)), (None,None,None))
gK = cute.local_tile(mK, cute.slice_(self.qk_mma_tiler, (0,None,None)), (None,None,None))
# PV (128,128) output partitioned with pv_thr — c tensor is (128,64) but we write full (128,128)
gC = cute.local_tile(mC, cute.slice_(self.qk_mma_tiler, (None,0,None)), (None,None,None))
k_cnt = cute.size(gQ, mode=[3])
qk_thr = qk_mma.get_slice(0)
pv_thr = pv_mma.get_slice(0)
tCgQ = qk_thr.partition_A(gQ); tCgK = qk_thr.partition_B(gK); tCgC = pv_thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
tAsQ, tAgQ = cpasync.tma_partition(tma_q, 0, a_lay, cute.group_modes(sQ,0,3), cute.group_modes(tCgQ,0,3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
tBsK, tBgK = cpasync.tma_partition(tma_k, 0, b_lay, cute.group_modes(sK,0,3), cute.group_modes(tCgK,0,3))
tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)]
gV = cute.local_tile(mV, cute.slice_(self.pv_mma_tiler, (0,None,None)), (None,None,None))
tCgV = pv_thr.partition_B(gV)
tVsV, tVgV = cpasync.tma_partition(tma_v, 0, b_lay, cute.group_modes(sV,0,3), cute.group_modes(tCgV,0,3))
tVgV = tVgV[(None,0,None,0)]
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
tCrV = pv_mma.make_fragment_B(sV)
qk_acc_shape = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
tOrP_base = pv_thr.make_fragment_A(tP)
tOrP = tOrP_base[(None, None, None, 0)]
tOrP0 = cute.make_tensor(
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
tOrP.layout)
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# TMA LOAD WARP
if warp_idx == self.tma_warp_id:
ab_p.reset(); peek = ab_p.try_acquire()
for kt in cutlass.range(k_cnt, unroll=1):
h = ab_p.acquire_and_advance(peek)
cute.copy(tma_q, tAgQ[(None,h.count)], tAsQ[(None,h.index)], tma_bar_ptr=h.barrier)
cute.copy(tma_k, tBgK[(None,h.count)], tBsK[(None,h.index)], tma_bar_ptr=h.barrier)
cute.copy(tma_v, tVgV[(None,h.count)], tVsV[(None,h.index)], tma_bar_ptr=h.barrier)
peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_p.try_acquire()
ab_p.tail()
# MMA WARP
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
ab_c.reset(); peek = ab_c.try_wait()
s0_handle = mma_si_prod.acquire_and_advance()
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
acc_pipe.producer_acquire(acc_prod_st)
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kt in range(k_cnt):
h = ab_c.wait_and_advance(peek)
nblk = cute.size(tCrQ, mode=[2])
for kb in cutlass.range(nblk, unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,h.index)], tCrK[(None,None,kb,h.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
h.release(); peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_c.try_wait()
cute.arch.fence_view_async_tmem_store()
s0_handle.commit()
s0_handle = mma_si_prod.acquire_and_advance()
pv_mma.set(tcgen05.Field.ACCUMULATE, False)
tCrV_s = tCrV[(None, None, None, 0)]
nblk_pv = cute.size(tOrP0, mode=[2])
for kb in cutlass.range(nblk_pv, unroll_full=True):
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
acc_pipe.producer_commit(acc_prod_st)
acc_prod_st.advance()
acc_pipe.producer_tail(acc_prod_st)
# EPILOGUE WARPS
if warp_idx < self.mma_warp_id:
tmem.allocate(self.num_tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
tmem_load_atom = cute.make_copy_atom(
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
thr_load = tiled_tmem_load.get_slice(sfw_idx)
tTMEM_LOADtS = thr_load.partition_S(tStS0)
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
tScS = qk_thr.partition_C(cS)
tTMEM_LOADcS = thr_load.partition_D(tScS)
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, self.tilePlikeFP32)))
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
tmem_store_atom = cute.make_copy_atom(
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS_P)
thr_store = tiled_tmem_store.get_slice(sfw_idx)
tTMEM_STOREtS_x4 = thr_store.partition_D(tStS_P)
tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, self.tilePlikeFP32)))
tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout)
tTMEM_STOREcS = thr_store.partition_S(tScS_P)
si_handle = mma_si_cons.wait_and_advance()
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
tTMEM_STORErS_x4 = cute.make_rmem_tensor(tTMEM_STOREcS.shape, self.qk_acc_dtype)
tTMEM_STORErS_x4_e = cute.make_tensor(
cute.recast_ptr(tTMEM_STORErS_x4.iterator, dtype=self.q_dtype),
tTMEM_LOADrS.layout)
frg_cnt = 4
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
tTMEM_STORErS_x4_e_frg = cute.logical_divide(
tTMEM_STORErS_x4_e, cute.make_layout(frg_cnt))
for j in range(frg_cnt):
s_vec = tTMEM_LOADrS_frg[None, j].load()
tTMEM_STORErS_x4_e_frg[None, j].store(s_vec.to(self.q_dtype))
cute.copy(tiled_tmem_store, tTMEM_STORErS_x4, tTMEM_STOREtS_x4)
cute.arch.fence_view_async_tmem_store()
si_handle.release()
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
c_pipe.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
def test():
torch.manual_seed(42)
m, n, head_dim = 128, 128, 64
q = torch.randn(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')
k = torch.randn(n, head_dim, 1, dtype=torch.bfloat16, device='cuda')
# V = truncated identity (64, 128) MN-major
v_data = torch.zeros(head_dim, n, dtype=torch.bfloat16, device='cuda')
for d in range(head_dim):
v_data[d, d] = 1.0
v = v_data.as_strided((head_dim, n), (1, head_dim)).unsqueeze(-1)
# Output is (128, 128) even though we only care about first 64 columns
c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
qf = q[:,:,0].float()
kf = k[:,:,0].float()
# With V=trunc_id: O = P @ V^T where V^T picks first 64 cols
# But PV MMA is (128,128), so full output is P @ V^T which is (128,128)
# Since V is (64,128) with V[d,k]=delta(d,k), V^T[k,d]=delta(d,k)
# P(128,128) @ V^T(128,64) = (128,64)... but MMA is (128,128)
# Need V(128,128) for (128,128) MMA, pad with zeros
v_full = torch.zeros(n, n, dtype=torch.bfloat16, device='cuda')
for d in range(head_dim):
v_full[d, d] = 1.0
v_full = v_full.as_strided((n, n), (1, n)).unsqueeze(-1)
ref = (qf @ kf.T).bfloat16().float()
# First 64 columns = P[:, :64], last 64 should be 0 (V has zeros there)
ref_64 = ref[:, :head_dim]
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
mV = ct.from_dlpack(v_full).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v_full))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = DiagVMma128Kernel(head_dim=head_dim)
print('Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mV, mC, stream)
print('Running...', flush=True)
compiled(mQ, mK, mV, mC, stream)
torch.cuda.synchronize()
out = c[:,:,0].float()
# Check first 64 columns
out_64 = out[:, :head_dim]
cos = torch.nn.functional.cosine_similarity(out_64.flatten().unsqueeze(0), ref_64.flatten().unsqueeze(0)).item()
max_err = (out_64 - ref_64).abs().max().item()
print('PV(128,128) MMA + V=trunc_id: first 64 cols cosine {:.6f}, max_err {:.6f} {}'.format(
cos, max_err, 'PASS' if cos >= 0.99 else 'FAIL'))
# Check that last 64 cols are ~0
out_last64 = out[:, head_dim:]
last64_max = out_last64.abs().max().item()
print('Last 64 cols max abs value: {:.6f} (should be ~0)'.format(last64_max))
if __name__ == '__main__':
test()

62
tests/test_diag_v_ones.py Normal file
View File

@@ -0,0 +1,62 @@
"""
Diagnostic: PV with V = all ones (64, 128) MN-major.
O[m, d] should be sum_k P[m, k] for all d (all columns identical).
If columns differ, the PV MMA is reading V incorrectly.
Also test: PV with V = single element (d=0, k=0) = 1, rest 0.
O[m, 0] = P[m, 0], all other entries 0.
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05, OperandMajorMode
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
import cutlass.torch as ct
# Reuse the same kernel as test_diag_v_truncid.py but with different V
def run_test(v_data, desc, head_dim=64):
m, n = 128, 128
q = torch.randn(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')
k = torch.randn(n, head_dim, 1, dtype=torch.bfloat16, device='cuda')
v = v_data.as_strided((head_dim, n), (1, head_dim)).unsqueeze(-1)
c = torch.zeros(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')
qf = q[:,:,0].float(); kf = k[:,:,0].float(); vf = v_data.float()
ref = (qf @ kf.T).bfloat16().float() @ vf.T
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
mV = ct.from_dlpack(v).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
# Import the kernel from test_diag_v_truncid
from test_diag_v_truncid import DiagVTruncIdKernel
kernel = DiagVTruncIdKernel(mma_tiler_mn=(128, 128), head_dim=head_dim)
print(f'Compiling {desc}...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mV, mC, stream)
print(f'Running {desc}...', flush=True)
compiled(mQ, mK, mV, mC, stream)
torch.cuda.synchronize()
out = c[:,:,0].float()
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
print(f'{desc}: cosine {cos:.6f} {"PASS" if cos >= 0.99 else "FAIL"}')
if cos < 0.99:
print(f' O[0,:5] = {out[0,:5].tolist()}')
print(f' ref[0,:5] = {ref[0,:5].tolist()}')
# Check if columns are all the same (for all-ones V)
if v_data.abs().sum() > 0:
col0 = out[:, 0]
all_same = all(torch.allclose(out[:, d], col0, atol=0.1) for d in range(min(8, head_dim)))
print(f' All columns same? {all_same}')
# Test 1: All ones V
v_ones = torch.ones(64, 128, dtype=torch.bfloat16, device='cuda')
run_test(v_ones, "All-ones V")
# Test 2: Single element V
v_single = torch.zeros(64, 128, dtype=torch.bfloat16, device='cuda')
v_single[0, 0] = 1.0
run_test(v_single, "Single-element V")

View File

@@ -0,0 +1,369 @@
"""
Diagnostic: PV with truncated identity V (64,128).
V[d,k] = 1 if d==k, 0 otherwise. MN-major.
With identity softmax: O = P[:, :64] = (Q@K^T).bfloat16()[:, :64].
If PV MMA is correct, cosine should be ~0.999.
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05, OperandMajorMode
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
import cutlass.torch as ct
class DiagVTruncIdKernel:
def __init__(self, mma_tiler_mn, head_dim):
self.head_dim = head_dim
self.acc_dtype = Float32; self.qk_acc_dtype = Float32
self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16
self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
self.use_2cta_instrs = False
self.epilog_sync_bar_id = 1
self.cluster_shape_mn = (1, 1)
self.cta_group = tcgen05.CtaGroup.ONE
self.epilogue_warp_id = (0, 1, 2, 3)
self.mma_warp_id = 4; self.tma_warp_id = 5
self.threads_per_cta = 192
self.num_c_stage = 2
def _setup(self, qk_mma, pv_mma):
qk_inst_k = cute.size(qk_mma.shape_mnk, mode=[2])
self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
self.pv_mma_tiler = (self.qk_mma_tiler[0], self.qk_mma_tiler[2], self.qk_mma_tiler[1])
self.mma_tiler = self.qk_mma_tiler
self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
self.qk_cta_tile_shape_mnk = (
self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
self.qk_mma_tiler[1], self.qk_mma_tiler[2])
# PV cta tile for epilogue
self.pv_cta_tile_shape_mnk = (
self.pv_mma_tiler[0] // cute.size(pv_mma.thr_id.shape),
self.pv_mma_tiler[1], self.pv_mma_tiler[2])
self.cta_tile_shape_mnk = self.pv_cta_tile_shape_mnk
self.c_layout = LayoutEnum.ROW_MAJOR
self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
self.pv_cta_tile_shape_mnk, self.use_2cta_instrs, self.c_layout, self.o_dtype)
print(f"[SETUP] qk_mma_tiler={self.qk_mma_tiler}, pv_mma_tiler={self.pv_mma_tiler}")
print(f"[SETUP] epi_tile={self.epi_tile}")
self.num_ab_stage = 1; self.num_acc_stage = 1
self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1)
self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
qk_thr = qk_mma.get_slice(0)
qk_acc_shape = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
s_cols = find_tmem_tensor_col_offset(tStS)
pv_thr = pv_mma.get_slice(0)
pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
o_cols = find_tmem_tensor_col_offset(tOtO)
self.tilePlikeFP32 = self.qk_mma_tiler[1] // Float32.width * self.o_dtype.width
self.tmem_s0_offset = 0
self.tmem_p0_offset = 32
self.tmem_o0_offset = s_cols
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
v_smem = cute.slice_(self.v_smem_s, (None, None, None, 0))
self.num_tma_load_bytes = (
cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem) +
cute.size_in_bytes(self.q_dtype, v_smem)
) * cute.size(qk_mma.thr_id.shape)
@cute.jit
def __call__(self, q, k, v, c, stream):
self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype
self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
self.b_major = LayoutEnum.from_tensor(k).mma_major_mode()
self.v_major = LayoutEnum.from_tensor(v).mma_major_mode()
self.c_layout = LayoutEnum.from_tensor(c)
qk_mma = utils.sm100.make_trivial_tiled_mma(
self.q_dtype, self.q_dtype, self.a_major, self.b_major,
self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
pv_mma_tiler_mn = (self.mma_tiler_mn[0], self.head_dim)
pv_mma = utils.sm100.make_trivial_tiled_mma(
self.q_dtype, self.q_dtype, OperandMajorMode.K, self.v_major,
self.qk_acc_dtype, self.cta_group, pv_mma_tiler_mn, tcgen05.OperandSource.TMEM)
self._setup(qk_mma, pv_mma)
q_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
k_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
v_smem = cute.slice_(self.v_smem_s, (None, None, None, 0))
tma_q, tma_tq = cute.nvgpu.make_tiled_tma_atom_A(
utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
q, q_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
tma_k, tma_tk = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
k, k_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
tma_v, tma_tv = cute.nvgpu.make_tiled_tma_atom_B(
utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, pv_mma.thr_id),
v, v_smem, self.pv_mma_tiler, pv_mma, self.cluster_layout_vmnk.shape)
epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
self._kernel(qk_mma, pv_mma, tma_q, tma_tq, tma_k, tma_tk, tma_v, tma_tv,
tma_c, tma_tc, self.cluster_layout_vmnk,
self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
@cute.kernel
def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV,
tma_c, mC, cl_vmnk, a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
tidx, _, _ = cute.arch.thread_idx()
use_2cta = cute.size(qk_mma.thr_id.shape) == 2
if warp_idx == self.tma_warp_id:
cpasync.prefetch_descriptor(tma_q); cpasync.prefetch_descriptor(tma_k)
cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c)
@cute.struct
class SS:
ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage *2]
tmem_dealloc: cutlass.Int64
holding: cutlass.Int32
smem = utils.SmemAllocator(); st = smem.allocate(SS)
ab_p, ab_c = pipeline.PipelineTmaUmma.create(
barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
).make_participants()
mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
).make_participants()
acc_pipe = pipeline.PipelineUmmaAsync.create(
barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
consumer_group=pipeline.CooperativeGroup(
pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)),
cta_layout_vmnk=cl_vmnk, defer_sync=True)
tmem_bar = pipeline.NamedBarrier(barrier_id=2,
num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta,
two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
sQ = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
sK = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
sV = smem.allocate_tensor(element_type=self.q_dtype, layout=v_smem_s.outer, byte_alignment=128, swizzle=v_smem_s.inner)
sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
gQ = cute.local_tile(mQ, cute.slice_(self.qk_mma_tiler, (None,0,None)), (None,None,None))
gK = cute.local_tile(mK, cute.slice_(self.qk_mma_tiler, (0,None,None)), (None,None,None))
gC = cute.local_tile(mC, cute.slice_(self.pv_mma_tiler, (None,0,None)), (None,None,None))
k_cnt = cute.size(gQ, mode=[3])
qk_thr = qk_mma.get_slice(0)
pv_thr = pv_mma.get_slice(0)
tCgQ = qk_thr.partition_A(gQ); tCgK = qk_thr.partition_B(gK); tCgC = pv_thr.partition_C(gC)
a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
tAsQ, tAgQ = cpasync.tma_partition(tma_q, 0, a_lay, cute.group_modes(sQ,0,3), cute.group_modes(tCgQ,0,3))
b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
tBsK, tBgK = cpasync.tma_partition(tma_k, 0, b_lay, cute.group_modes(sK,0,3), cute.group_modes(tCgK,0,3))
tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)]
gV = cute.local_tile(mV, cute.slice_(self.pv_mma_tiler, (0,None,None)), (None,None,None))
tCgV = pv_thr.partition_B(gV)
tVsV, tVgV = cpasync.tma_partition(tma_v, 0, b_lay, cute.group_modes(sV,0,3), cute.group_modes(tCgV,0,3))
tVgV = tVgV[(None,0,None,0)]
tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
tCrV = pv_mma.make_fragment_B(sV)
qk_acc_shape = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_acc_shape)
tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_acc_shape)
tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
tP = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
tOrP_base = pv_thr.make_fragment_A(tP)
tOrP = tOrP_base[(None, None, None, 0)]
tOrP0 = cute.make_tensor(
tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
tOrP.layout)
tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
# TMA LOAD WARP
if warp_idx == self.tma_warp_id:
ab_p.reset(); peek = ab_p.try_acquire()
for kt in cutlass.range(k_cnt, unroll=1):
h = ab_p.acquire_and_advance(peek)
cute.copy(tma_q, tAgQ[(None,h.count)], tAsQ[(None,h.index)], tma_bar_ptr=h.barrier)
cute.copy(tma_k, tBgK[(None,h.count)], tBsK[(None,h.index)], tma_bar_ptr=h.barrier)
cute.copy(tma_v, tVgV[(None,h.count)], tVsV[(None,h.index)], tma_bar_ptr=h.barrier)
peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_p.try_acquire()
ab_p.tail()
# MMA WARP
if warp_idx == self.mma_warp_id:
tmem.wait_for_alloc()
ab_c.reset(); peek = ab_c.try_wait()
s0_handle = mma_si_prod.acquire_and_advance()
acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
acc_pipe.producer_acquire(acc_prod_st)
qk_mma.set(tcgen05.Field.ACCUMULATE, False)
for kt in range(k_cnt):
h = ab_c.wait_and_advance(peek)
nblk = cute.size(tCrQ, mode=[2])
for kb in cutlass.range(nblk, unroll_full=True):
cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,h.index)], tCrK[(None,None,kb,h.index)], tStS0)
qk_mma.set(tcgen05.Field.ACCUMULATE, True)
h.release(); peek = cutlass.Boolean(1)
if h.count+1<k_cnt: peek = ab_c.try_wait()
cute.arch.fence_view_async_tmem_store()
s0_handle.commit()
s0_handle = mma_si_prod.acquire_and_advance()
pv_mma.set(tcgen05.Field.ACCUMULATE, False)
tCrV_s = tCrV[(None, None, None, 0)]
nblk_pv = cute.size(tOrP0, mode=[2])
for kb in cutlass.range(nblk_pv, unroll_full=True):
cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
pv_mma.set(tcgen05.Field.ACCUMULATE, True)
acc_pipe.producer_commit(acc_prod_st)
acc_prod_st.advance()
acc_pipe.producer_tail(acc_prod_st)
# EPILOGUE WARPS
if warp_idx < self.mma_warp_id:
tmem.allocate(self.num_tmem_alloc_cols)
tmem.wait_for_alloc()
tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
# Softmax packing
tmem_load_atom = cute.make_copy_atom(
tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
thr_load = tiled_tmem_load.get_slice(sfw_idx)
tTMEM_LOADtS = thr_load.partition_S(tStS0)
cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
tScS = qk_thr.partition_C(cS)
tTMEM_LOADcS = thr_load.partition_D(tScS)
tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, self.tilePlikeFP32)))
tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
tmem_store_atom = cute.make_copy_atom(
tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS_P)
thr_store = tiled_tmem_store.get_slice(sfw_idx)
tTMEM_STOREtS_x4 = thr_store.partition_D(tStS_P)
tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, self.tilePlikeFP32)))
tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout)
tTMEM_STOREcS = thr_store.partition_S(tScS_P)
si_handle = mma_si_cons.wait_and_advance()
tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
tTMEM_STORErS_x4 = cute.make_rmem_tensor(tTMEM_STOREcS.shape, self.qk_acc_dtype)
tTMEM_STORErS_x4_e = cute.make_tensor(
cute.recast_ptr(tTMEM_STORErS_x4.iterator, dtype=self.q_dtype),
tTMEM_LOADrS.layout)
frg_cnt = 4
frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
tTMEM_STORErS_x4_e_frg = cute.logical_divide(
tTMEM_STORErS_x4_e, cute.make_layout(frg_tile))
for j in range(frg_cnt):
s_vec = tTMEM_LOADrS_frg[None, j].load()
tTMEM_STORErS_x4_e_frg[None, j].store(s_vec.to(self.q_dtype))
cute.copy(tiled_tmem_store, tTMEM_STORErS_x4, tTMEM_STOREtS_x4)
cute.arch.fence_view_async_tmem_store()
si_handle.release()
# Output epilogue
tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
c_pipe.producer_tail()
tmem.relinquish_alloc_permit()
tmem.free(tmem_ptr)
def test():
torch.manual_seed(42)
m, n, head_dim = 128, 128, 64
q = torch.randn(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')
k = torch.randn(n, head_dim, 1, dtype=torch.bfloat16, device='cuda')
# Truncated identity V: (64, 128) MN-major, V[d,k] = 1 if d==k, 0 otherwise
v_data = torch.zeros(head_dim, n, dtype=torch.bfloat16, device='cuda')
for d in range(head_dim):
v_data[d, d] = 1.0
v = v_data.as_strided((head_dim, n), (1, head_dim)).unsqueeze(-1)
c = torch.zeros(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')
# Reference: O = (Q@K^T).bfloat16()[:, :64] since V^T = first 64 cols of identity
qf = q[:,:,0].float()
kf = k[:,:,0].float()
ref = (qf @ kf.T).bfloat16().float()[:, :head_dim]
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
mV = ct.from_dlpack(v).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = DiagVTruncIdKernel(mma_tiler_mn=(128, 128), head_dim=head_dim)
print('Compiling...', flush=True)
compiled = cute.compile(kernel, mQ, mK, mV, mC, stream)
print('Running...', flush=True)
compiled(mQ, mK, mV, mC, stream)
torch.cuda.synchronize()
out = c[:,:,0].float()
cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
max_err = (out - ref).abs().max().item()
print('TruncId V (128x64 PV): cosine {:.6f}, max_err {:.6f} {}'.format(
cos, max_err, 'PASS' if cos >= 0.99 else 'FAIL'))
if cos < 0.99:
print('\n=== Element-level ===')
for m_idx in [0, 1, 63, 127]:
for d_idx in [0, 1, 31, 63]:
print(f' O[{m_idx},{d_idx}] = {out[m_idx,d_idx]:.4f}, ref = {ref[m_idx,d_idx]:.4f}')
print(f' O[0,:5] = {out[0,:5].tolist()}')
print(f' ref[0,:5] = {ref[0,:5].tolist()}')
if __name__ == '__main__':
test()