# Source: nvfp4-megamoe-kernel/tests/test_tmem_col_offset.py
# (file-browser listing metadata: 389 lines, 21 KiB, Python)
"""
Minimal PV-only test: Load P from GMEM to TMEM via QK-style MMA, then PV from TMEM.
Step 1: QK MMA writes FP32 S to TMEM (we know this works)
Step 2: Softmax packing writes BF16 P to TMEM (test this)
Step 3: PV MMA reads BF16 P from TMEM and V from SMEM, produces O
But to isolate the bug, let me test just the PV MMA in isolation.
I'll write known BF16 values to TMEM using the softmax packing path,
then immediately read them back using the PV A-fragment path,
and compare.
Actually, the simplest isolation test:
1. Do QK MMA to get S in TMEM (cosine 0.999999 verified)
2. Do softmax packing: S → P in TMEM (at offset 32)
3. Skip PV entirely — read P from TMEM using the C-fragment composition LOAD path
4. Output P to GMEM and compare against S.to(BF16)
This tests whether the softmax packing writes P correctly to the same TMEM
that the PV would read from.
But we can't easily read P from TMEM using the standard epilogue path
because the epilogue expects FP32 accumulator data.
Alternative: Use the PV MMA with V=I (identity). If P is correct,
then P @ I = P. But V needs to be MN-major and (128, 128), not (128, 64).
The output would be (128, 128) which doesn't match our (128, 64) c tensor.
Let me use V that selects the first 64 columns: V[k, n] = delta(k, n) for k in [0,63].
This gives P @ V = P[:, :64], and the output is (128, 64).
But V is (128, 128) in the MMA K,N dims. V[k, n] for k in [0,127], n in [0,63].
The column-selecting V is getting complicated, so the test below uses the plain identity approach instead: V is a full (128, 128) identity and the output tensor is (128, 128).
"""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
import cutlass.torch as ct
class TmemColOffsetKernel:
    """Single-CTA test kernel: QK MMA -> BF16 packing of S into TMEM -> PV MMA with V=I.

    The epilogue warps load the FP32 QK accumulator S from TMEM, convert it to
    ``q_dtype`` (no softmax arithmetic is applied) and store it back to TMEM as P.
    The PV MMA then consumes P as its TMEM A operand. With V=I, O = P @ I = P,
    so the output should be P = S.to(BF16), i.e. (Q@K^T).bfloat16().

    V is (K=128, N=128) in the MMA, so a 128x128 identity in MN-major is needed
    and the output tensor is (128, 128).
    """
    def __init__(self, mma_tiler_mn):
        """Store the static configuration of the kernel.

        Args:
            mma_tiler_mn: (M, N) tile shape handed to both tiled MMAs;
                ``test()`` in this file passes (128, 128).
        """
        self.acc_dtype = Float32; self.qk_acc_dtype = Float32
        self.q_dtype = BFloat16; self.o_dtype = BFloat16; self.c_dtype = BFloat16
        # The K extent of 1 is only a placeholder: _setup() overwrites mma_tiler
        # with the QK tiler once the MMA instruction shape is known.
        self.mma_tiler_mn = mma_tiler_mn; self.mma_tiler = (*mma_tiler_mn, 1)
        self.use_2cta_instrs = False  # needed by epilogue_tma_store
        self.epilog_sync_bar_id = 1  # needed by epilogue_tma_store
        self.cluster_shape_mn = (1, 1)
        self.cta_group = tcgen05.CtaGroup.ONE
        # Warp roles: warps 0-3 run the packing + output epilogue, warp 4 issues
        # the MMAs, warp 5 issues the TMA loads.
        self.epilogue_warp_id = (0, 1, 2, 3)
        self.mma_warp_id = 4; self.tma_warp_id = 5
        # 6 warps (ids 0-5 above) x 32 threads.
        self.threads_per_cta = 192
        self.num_c_stage = 2
    def _setup(self, qk_mma, pv_mma):
        """Derive tilers, SMEM/TMEM layouts, TMEM offsets and the TMA byte count.

        Called from ``__call__`` after both tiled MMAs have been built; all
        results are stored as attributes on ``self``.

        Args:
            qk_mma: tiled MMA for S = Q @ K^T (built with operands sourced from SMEM).
            pv_mma: tiled MMA for O = P @ V (built with the A operand sourced from TMEM).
        """
        # K extent of the QK tile: four times the K size of one MMA instruction.
        qk_inst_k = int(cute.size(qk_mma.shape_mnk, mode=[2]))
        self.qk_mma_tiler = (*self.mma_tiler_mn, qk_inst_k * 4)
        # PV with V=I: output is (128, 128), same as QK
        self.pv_mma_tiler = (self.qk_mma_tiler[0], self.qk_mma_tiler[1], self.qk_mma_tiler[1])
        # pv_mma_tiler = (128, 128, 128) since V is 128x128
        self.mma_tiler = self.qk_mma_tiler
        self.cluster_layout_vmnk = cute.tiled_divide(cute.make_layout((1,1,1)), (qk_mma.thr_id.shape,))
        self.cta_tile_shape_mnk = (
            self.qk_mma_tiler[0] // cute.size(qk_mma.thr_id.shape),
            self.qk_mma_tiler[1], self.qk_mma_tiler[2])
        self.c_layout = LayoutEnum.ROW_MAJOR
        self.epi_tile = utils.sm100.compute_epilogue_tile_shape(
            self.cta_tile_shape_mnk, False, self.c_layout, self.o_dtype)
        self.num_ab_stage = 1; self.num_acc_stage = 1
        # Single-stage staging layouts for Q, K (QK MMA) and V (PV MMA).
        self.a_smem_s = utils.sm100.make_smem_layout_a(qk_mma, self.mma_tiler, self.q_dtype, 1)
        self.b_smem_s = utils.sm100.make_smem_layout_b(qk_mma, self.mma_tiler, self.q_dtype, 1)
        self.v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
        # Layout of the PV A operand (P). It is built with the SMEM layout helper,
        # but only its .outer layout is used, to describe P living in TMEM.
        self.p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, self.pv_mma_tiler, self.q_dtype, 1)
        # NOTE(review): stage count 2 is hard-coded here and equals num_c_stage
        # set in __init__ — keep the two in sync.
        self.c_smem_s = utils.sm100.make_smem_layout_epi(self.o_dtype, self.c_layout, self.epi_tile, 2)
        # Accumulator fragments for S (QK) and O (PV), used only to query TMEM
        # column offsets.
        qk_thr = qk_mma.get_slice(0)
        qk_acc_shape = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
        tStS = qk_thr.make_fragment_C(qk_acc_shape)
        s_cols = find_tmem_tensor_col_offset(tStS)
        pv_thr = pv_mma.get_slice(0)
        pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
        tOtO = pv_thr.make_fragment_C(pv_acc_shape)
        o_cols = find_tmem_tensor_col_offset(tOtO)
        # N // 32 * 16: presumably the width of the BF16 P tile counted in
        # FP32-sized TMEM columns — TODO confirm.
        self.tilePlikeFP32 = self.qk_mma_tiler[1] // Float32.width * self.o_dtype.width
        # TMEM placement, as offsets added to FP32-typed iterators in _kernel:
        # S at 0, P at 32 (overlapping the S region), O at the offset reported
        # for the S fragment.
        # NOTE(review): assumes find_tmem_tensor_col_offset(tStS) yields the first
        # column past S — confirm against the helper.
        self.tmem_s0_offset = 0
        self.tmem_p0_offset = 32
        self.tmem_o0_offset = s_cols
        # Compare TMEM column offsets: QK C-fragment S vs PV A-fragment P
        tP = cute.make_tensor(tStS.iterator, self.p_tmem_s.outer)
        tOrP2 = pv_mma.get_slice(0).make_fragment_A(tP)
        pv_col = find_tmem_tensor_col_offset(tOrP2)
        # Diagnostic output only; none of the printed values feed back into the kernel
        # except s_cols (via tmem_o0_offset above).
        print(f"TMEM col: S_base={s_cols}, P_Afrag_at_S_base={pv_col}, O_base={o_cols}")
        # Staged accumulator fragments are used only to size the TMEM allocation.
        tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
        tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
        self.num_tmem_alloc_cols = utils.get_num_tmem_alloc_cols([tCtS_fake, tCtO_fake], arch="sm_100")
        # ⛔⛔⛔ CRITICAL: num_tma_load_bytes MUST include ALL TMA-loaded tensors (Q + K + V). Missing V → DEADLOCK. See FOOTGUN #0 in README.
        a_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
        b_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
        v_smem = cute.slice_(self.v_smem_s, (None, None, None, 0))
        # Bytes of one stage of Q + K + V; passed as tx_count to the TMA->MMA
        # pipeline in _kernel.
        self.num_tma_load_bytes = (
            cute.size_in_bytes(self.q_dtype, a_smem) + cute.size_in_bytes(self.q_dtype, b_smem) +
            cute.size_in_bytes(self.q_dtype, v_smem)
        ) * cute.size(qk_mma.thr_id.shape)
    @cute.jit
    def __call__(self, q, k, v, c, stream):
        """Build the tiled MMAs and TMA atoms, then launch ``_kernel`` on a 1x1x1 grid.

        Args:
            q: Q tensor, used as the A operand of the QK MMA.
            k: K tensor, used as the B operand of the QK MMA.
            v: V tensor, used as the B operand of the PV MMA.
            c: output tensor written by the TMA-store epilogue.
            stream: CUDA stream the kernel is launched on.
        """
        # Element types and major modes are taken from the actual tensors,
        # overriding the defaults set in __init__/_setup.
        self.q_dtype = q.element_type; self.o_dtype = c.element_type; self.c_dtype = self.o_dtype
        self.a_major = LayoutEnum.from_tensor(q).mma_major_mode()
        self.b_major = LayoutEnum.from_tensor(k).mma_major_mode()
        self.v_major = LayoutEnum.from_tensor(v).mma_major_mode()
        self.c_layout = LayoutEnum.from_tensor(c)
        # QK MMA: both operands come from SMEM.
        qk_mma = utils.sm100.make_trivial_tiled_mma(
            self.q_dtype, self.q_dtype, self.a_major, self.b_major,
            self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.SMEM)
        # PV with 128x128 output (V=I)
        # A operand (P) is K-major and sourced from TMEM; B operand (V) from SMEM.
        pv_mma = utils.sm100.make_trivial_tiled_mma(
            self.q_dtype, self.q_dtype, cute.nvgpu.OperandMajorMode.K, self.v_major,
            self.qk_acc_dtype, self.cta_group, self.mma_tiler_mn, tcgen05.OperandSource.TMEM)
        self._setup(qk_mma, pv_mma)
        # Stage-0 slices of the staged SMEM layouts, used to build the TMA atoms.
        q_smem = cute.slice_(self.a_smem_s, (None, None, None, 0))
        k_smem = cute.slice_(self.b_smem_s, (None, None, None, 0))
        v_smem = cute.slice_(self.v_smem_s, (None, None, None, 0))
        tma_q, tma_tq = cute.nvgpu.make_tiled_tma_atom_A(
            utils.sm100.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, qk_mma.thr_id),
            q, q_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
        tma_k, tma_tk = cute.nvgpu.make_tiled_tma_atom_B(
            utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, qk_mma.thr_id),
            k, k_smem, self.mma_tiler, qk_mma, self.cluster_layout_vmnk.shape)
        # V uses the PV tiler and PV MMA, since it is the B operand of P @ V.
        tma_v, tma_tv = cute.nvgpu.make_tiled_tma_atom_B(
            utils.sm100.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, pv_mma.thr_id),
            v, v_smem, self.pv_mma_tiler, pv_mma, self.cluster_layout_vmnk.shape)
        # TMA store atom for the output, built from the first two modes of the
        # epilogue SMEM layout (the stage mode is dropped).
        epi_smem = cute.select(self.c_smem_s, mode=[0, 1])
        tma_c, tma_tc = cpasync.make_tiled_tma_atom(cpasync.CopyBulkTensorTileS2GOp(), c, epi_smem, self.epi_tile)
        self._kernel(qk_mma, pv_mma, tma_q, tma_tq, tma_k, tma_tk, tma_v, tma_tv,
            tma_c, tma_tc, self.cluster_layout_vmnk,
            self.a_smem_s, self.b_smem_s, self.v_smem_s, self.p_tmem_s, self.c_smem_s, self.epi_tile
        ).launch(grid=(1,1,1), block=[self.threads_per_cta,1,1], stream=stream)
    @cute.kernel
    def _kernel(self, qk_mma, pv_mma, tma_q, mQ, tma_k, mK, tma_v, mV,
            tma_c, mC, cl_vmnk, a_smem_s, b_smem_s, v_smem_s, p_tmem_s, c_smem_s, epi_tile):
        """Device kernel with three warp roles: TMA load, MMA issue, pack + epilogue.

        Flow: the TMA warp loads Q, K and V; the MMA warp computes S = Q @ K^T
        into TMEM and hands S to the epilogue warps; those convert S to
        ``q_dtype`` and store it back to TMEM as P; the MMA warp then computes
        O = P @ V and the epilogue warps write O out through a TMA store.

        Args:
            qk_mma, pv_mma: tiled MMAs built in ``__call__``.
            tma_q, tma_k, tma_v, tma_c: TMA atoms for Q, K, V loads and the C store.
            mQ, mK, mV, mC: TMA tensors paired with the atoms above.
            cl_vmnk: cluster layout computed in ``_setup``.
            a_smem_s, b_smem_s, v_smem_s, c_smem_s: staged SMEM layouts.
            p_tmem_s: layout describing the PV A operand (P) in TMEM.
            epi_tile: epilogue tile shape.
        """
        warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
        tidx, _, _ = cute.arch.thread_idx()
        use_2cta = cute.size(qk_mma.thr_id.shape) == 2
        if warp_idx == self.tma_warp_id:
            cpasync.prefetch_descriptor(tma_q); cpasync.prefetch_descriptor(tma_k)
            cpasync.prefetch_descriptor(tma_v); cpasync.prefetch_descriptor(tma_c)
        # Shared storage for pipeline barriers and TMEM allocator bookkeeping.
        @cute.struct
        class SS:
            ab_bar: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
            mma_si_bar: cute.struct.MemRange[cutlass.Int64, 2]
            acc_bar: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]
            tmem_dealloc: cutlass.Int64
            holding: cutlass.Int32
        smem = utils.SmemAllocator(); st = smem.allocate(SS)
        # TMA -> MMA pipeline for the Q/K/V loads; tx_count must cover all three
        # tensors (see num_tma_load_bytes in _setup).
        ab_p, ab_c = pipeline.PipelineTmaUmma.create(
            barrier_storage=st.ab_bar.data_ptr(), num_stages=self.num_ab_stage,
            producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
            consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1),
            tx_count=self.num_tma_load_bytes, cta_layout_vmnk=cl_vmnk, defer_sync=True
        ).make_participants()
        # MMA <-> epilogue-warp handshake around S/P: the MMA warp commits once S
        # is written, and all epilogue threads release once P is stored.
        mma_si_prod, mma_si_cons = pipeline.PipelineUmmaAsync.create(
            barrier_storage=st.mma_si_bar.data_ptr(), num_stages=1,
            producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
            consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id)),
        ).make_participants()
        # MMA -> epilogue pipeline for the O accumulator.
        acc_pipe = pipeline.PipelineUmmaAsync.create(
            barrier_storage=st.acc_bar.data_ptr(), num_stages=self.num_acc_stage,
            producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread),
            consumer_group=pipeline.CooperativeGroup(
                pipeline.Agent.Thread, len(self.epilogue_warp_id) * (2 if use_2cta else 1)),
            cta_layout_vmnk=cl_vmnk, defer_sync=True)
        # Named barrier sized for the MMA warp plus the epilogue warps, used by
        # the TMEM allocator; the allocator warp is the first epilogue warp.
        tmem_bar = pipeline.NamedBarrier(barrier_id=2,
            num_threads=32 * len((self.mma_warp_id, *self.epilogue_warp_id)))
        tmem = utils.TmemAllocator(st.holding.ptr, barrier_for_retrieve=tmem_bar,
            allocator_warp_id=self.epilogue_warp_id[0], is_two_cta=use_2cta,
            two_cta_tmem_dealloc_mbar_ptr=st.tmem_dealloc.ptr)
        pipeline.pipeline_init_arrive(cluster_shape_mn=cl_vmnk, is_relaxed=True)
        # SMEM staging buffers for Q, K, V and the output tile.
        sQ = smem.allocate_tensor(element_type=self.q_dtype, layout=a_smem_s.outer, byte_alignment=128, swizzle=a_smem_s.inner)
        sK = smem.allocate_tensor(element_type=self.q_dtype, layout=b_smem_s.outer, byte_alignment=128, swizzle=b_smem_s.inner)
        sV = smem.allocate_tensor(element_type=self.q_dtype, layout=v_smem_s.outer, byte_alignment=128, swizzle=v_smem_s.inner)
        sC = smem.allocate_tensor(element_type=self.o_dtype, layout=c_smem_s.outer, byte_alignment=128, swizzle=c_smem_s.inner)
        # Tile the global tensors with the QK tiler.
        gQ = cute.local_tile(mQ, cute.slice_(self.qk_mma_tiler, (None,0,None)), (None,None,None))
        gK = cute.local_tile(mK, cute.slice_(self.qk_mma_tiler, (0,None,None)), (None,None,None))
        gC = cute.local_tile(mC, cute.slice_(self.qk_mma_tiler, (None,None,0)), (None,None,None))
        # Size of mode 3 of gQ, used below as the K-tile loop count.
        k_cnt = cute.size(gQ, mode=[3])
        qk_thr = qk_mma.get_slice(0)
        pv_thr = pv_mma.get_slice(0)
        tCgQ = qk_thr.partition_A(gQ); tCgK = qk_thr.partition_B(gK); tCgC = qk_thr.partition_C(gC)
        # TMA partitions (SMEM side / GMEM side) for Q and K.
        a_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,0,None,0)).shape)
        tAsQ, tAgQ = cpasync.tma_partition(tma_q, 0, a_lay, cute.group_modes(sQ,0,3), cute.group_modes(tCgQ,0,3))
        b_lay = cute.make_layout(cute.slice_(cl_vmnk, (0,None,0,0)).shape)
        tBsK, tBgK = cpasync.tma_partition(tma_k, 0, b_lay, cute.group_modes(sK,0,3), cute.group_modes(tCgK,0,3))
        tAgQ = tAgQ[(None,0,None,0)]; tBgK = tBgK[(None,0,None,0)]
        # V is tiled with the PV tiler and partitioned as the PV MMA's B operand.
        gV = cute.local_tile(mV, cute.slice_(self.pv_mma_tiler, (0,None,None)), (None,None,None))
        tCgV = pv_thr.partition_B(gV)
        tVsV, tVgV = cpasync.tma_partition(tma_v, 0, b_lay, cute.group_modes(sV,0,3), cute.group_modes(tCgV,0,3))
        tVgV = tVgV[(None,0,None,0)]
        # MMA operand fragments over the SMEM buffers.
        tCrQ = qk_mma.make_fragment_A(sQ); tCrK = qk_mma.make_fragment_B(sK)
        tCrV = pv_mma.make_fragment_B(sV)
        # S accumulator view in TMEM at tmem_s0_offset.
        # NOTE(review): these views start from the fragment's own iterator, not
        # from the pointer returned by the TMEM allocator — presumably relies on
        # the allocation starting at the fragment's base address; confirm.
        qk_acc_shape = qk_thr.partition_shape_C(self.qk_mma_tiler[:2])
        tStS = qk_thr.make_fragment_C(qk_acc_shape)
        tStS0 = cute.make_tensor(tStS.iterator + self.tmem_s0_offset, tStS.layout)
        # O accumulator view in TMEM at tmem_o0_offset.
        pv_acc_shape = pv_thr.partition_shape_C(self.pv_mma_tiler[:2])
        tOtO = pv_thr.make_fragment_C(pv_acc_shape)
        tOtO0 = cute.make_tensor(tOtO.iterator + self.tmem_o0_offset, tOtO.layout)
        # P as the PV A operand: reuse the S iterator with P's layout, take
        # stage 0, then shift by tmem_p0_offset. The offset is scaled by the
        # FP32/BF16 width ratio because this iterator is typed for q_dtype.
        tP = cute.make_tensor(tStS.iterator, self.p_tmem_s.outer)
        tOrP_base = pv_thr.make_fragment_A(tP)
        tOrP = tOrP_base[(None, None, None, 0)]
        tOrP0 = cute.make_tensor(
            tOrP.iterator + self.qk_acc_dtype.width // self.q_dtype.width * self.tmem_p0_offset,
            tOrP.layout)
        # Staged accumulator fragments; tCtO_fake supplies the layout for the
        # output epilogue below (tCtS_fake is not referenced again in this kernel).
        tCtS_fake = qk_mma.make_fragment_C(cute.append(qk_acc_shape, self.num_acc_stage))
        tCtO_fake = pv_mma.make_fragment_C(cute.append(pv_acc_shape, self.num_acc_stage))
        pipeline.pipeline_init_wait(cluster_shape_mn=cl_vmnk)
        # ═══ TMA LOAD WARP ═══
        if warp_idx == self.tma_warp_id:
            ab_p.reset(); peek = ab_p.try_acquire()
            for kt in cutlass.range(k_cnt, unroll=1):
                h = ab_p.acquire_and_advance(peek)
                # Q, K and V all signal the same stage barrier, which is why
                # tx_count has to include the bytes of all three.
                cute.copy(tma_q, tAgQ[(None,h.count)], tAsQ[(None,h.index)], tma_bar_ptr=h.barrier)
                cute.copy(tma_k, tBgK[(None,h.count)], tBsK[(None,h.index)], tma_bar_ptr=h.barrier)
                cute.copy(tma_v, tVgV[(None,h.count)], tVsV[(None,h.index)], tma_bar_ptr=h.barrier)
                peek = cutlass.Boolean(1)
                if h.count+1<k_cnt: peek = ab_p.try_acquire()
            ab_p.tail()
        # ═══ MMA WARP ═══
        if warp_idx == self.mma_warp_id:
            tmem.wait_for_alloc()
            ab_c.reset(); peek = ab_c.try_wait()
            s0_handle = mma_si_prod.acquire_and_advance()
            acc_prod_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, self.num_acc_stage)
            acc_pipe.producer_acquire(acc_prod_st)
            # QK MMA: the first instruction overwrites S, later ones accumulate.
            qk_mma.set(tcgen05.Field.ACCUMULATE, False)
            for kt in range(k_cnt):
                h = ab_c.wait_and_advance(peek)
                nblk = cute.size(tCrQ, mode=[2])
                for kb in cutlass.range(nblk, unroll_full=True):
                    cute.gemm(qk_mma, tStS0, tCrQ[(None,None,kb,h.index)], tCrK[(None,None,kb,h.index)], tStS0)
                    qk_mma.set(tcgen05.Field.ACCUMULATE, True)
                # NOTE(review): the stage is released here, before the PV MMA below
                # reads V from that same stage (tCrV stage 0) — confirm this is
                # safe for k_cnt > 1.
                h.release(); peek = cutlass.Boolean(1)
                if h.count+1<k_cnt: peek = ab_c.try_wait()
            cute.arch.fence_view_async_tmem_store()
            # Hand S to the epilogue warps, then acquire again before issuing PV.
            # NOTE(review): with num_stages=1 the second acquire presumably blocks
            # until the epilogue warps release the stage after storing P.
            s0_handle.commit()
            s0_handle = mma_si_prod.acquire_and_advance()
            # PV MMA: P @ V where V=I → O = P
            pv_mma.set(tcgen05.Field.ACCUMULATE, False)
            tCrV_s = tCrV[(None, None, None, 0)]
            nblk_pv = cute.size(tOrP0, mode=[2])
            for kb in cutlass.range(nblk_pv, unroll_full=True):
                cute.gemm(pv_mma, tOtO0, tOrP0[(None,None,kb)], tCrV_s[(None,None,kb)], tOtO0)
                pv_mma.set(tcgen05.Field.ACCUMULATE, True)
            # Publish O to the output epilogue.
            acc_pipe.producer_commit(acc_prod_st)
            acc_prod_st.advance()
            acc_pipe.producer_tail(acc_prod_st)
        # ═══ EPILOGUE WARPS ═══
        if warp_idx < self.mma_warp_id:
            tmem.allocate(self.num_tmem_alloc_cols)
            tmem.wait_for_alloc()
            tmem_ptr = tmem.retrieve_ptr(self.qk_acc_dtype)
            # Thread index within the group of epilogue warps.
            sfw_idx = tidx % (32 * len(self.epilogue_warp_id))
            # TMEM -> register load of S as FP32.
            tmem_load_atom = cute.make_copy_atom(
                tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
            tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS0)
            thr_load = tiled_tmem_load.get_slice(sfw_idx)
            tTMEM_LOADtS = thr_load.partition_S(tStS0)
            # Identity (coordinate) tensor over the S tile; used only to derive
            # per-thread register fragment shapes.
            cS = cute.make_identity_tensor((self.qk_mma_tiler[0], self.qk_mma_tiler[1]))
            tScS = qk_thr.partition_C(cS)
            tTMEM_LOADcS = thr_load.partition_D(tScS)
            # Register -> TMEM store of P: a (128, tilePlikeFP32) FP32-typed view
            # of the S layout, placed at tmem_p0_offset.
            tStS_P_layout = cute.composition(tStS.layout, cute.make_layout((128, self.tilePlikeFP32)))
            tStS_P = cute.make_tensor(tStS.iterator + self.tmem_p0_offset, tStS_P_layout)
            tmem_store_atom = cute.make_copy_atom(
                tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), self.qk_acc_dtype)
            tiled_tmem_store = tcgen05.make_tmem_copy(tmem_store_atom, tStS_P)
            thr_store = tiled_tmem_store.get_slice(sfw_idx)
            tTMEM_STOREtS_x4 = thr_store.partition_D(tStS_P)
            tScS_P_layout = cute.composition(tScS.layout, cute.make_layout((128, self.tilePlikeFP32)))
            tScS_P = cute.make_tensor(tScS.iterator, tScS_P_layout)
            tTMEM_STOREcS = thr_store.partition_S(tScS_P)
            # Wait until the MMA warp has committed S.
            si_handle = mma_si_cons.wait_and_advance()
            tTMEM_LOADrS = cute.make_rmem_tensor(tTMEM_LOADcS.shape, self.qk_acc_dtype)
            cute.copy(tiled_tmem_load, tTMEM_LOADtS, tTMEM_LOADrS)
            # The store buffer is allocated as FP32 and re-viewed as q_dtype with
            # the load fragment's layout, so the converted values are written
            # into the FP32-typed registers that the store copies out.
            tTMEM_STORErS_x4 = cute.make_rmem_tensor(tTMEM_STOREcS.shape, self.qk_acc_dtype)
            tTMEM_STORErS_x4_e = cute.make_tensor(
                cute.recast_ptr(tTMEM_STORErS_x4.iterator, dtype=self.q_dtype),
                tTMEM_LOADrS.layout)
            # Convert S -> q_dtype in four equal fragments; no other arithmetic
            # is applied to the values.
            frg_cnt = 4
            frg_tile = cute.size(tTMEM_LOADrS) // frg_cnt
            tTMEM_LOADrS_frg = cute.logical_divide(tTMEM_LOADrS, cute.make_layout(frg_tile))
            tTMEM_STORErS_x4_e_frg = cute.logical_divide(
                tTMEM_STORErS_x4_e, cute.make_layout(frg_tile))
            for j in range(frg_cnt):
                s_vec = tTMEM_LOADrS_frg[None, j].load()
                tTMEM_STORErS_x4_e_frg[None, j].store(s_vec.to(self.q_dtype))
            # Write P to TMEM, fence the store, then release so the MMA warp can
            # start the PV MMA.
            cute.copy(tiled_tmem_store, tTMEM_STORErS_x4, tTMEM_STOREtS_x4)
            cute.arch.fence_view_async_tmem_store()
            si_handle.release()
            # Output epilogue
            # O is read from the allocated TMEM pointer plus tmem_o0_offset and
            # written to C by the shared epilogue helper, with an identity
            # element-wise op.
            tCtO_base = cute.make_tensor(tmem_ptr + self.tmem_o0_offset, tCtO_fake.layout)
            acc_cons_st = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, self.num_acc_stage)
            c_grp = pipeline.CooperativeGroup(pipeline.Agent.Thread, 32 * len(self.epilogue_warp_id))
            c_pipe = pipeline.PipelineTmaStore.create(num_stages=self.num_c_stage, producer_group=c_grp)
            acc_cons_st = utils.gemm.sm100.epilogue_tma_store(
                self, tidx, warp_idx, tma_c, tCtO_base, sC, tCgC,
                epi_tile, 0, const_expr(lambda x: x), (0,0,0), acc_cons_st, acc_pipe, c_pipe)
            c_pipe.producer_tail()
            tmem.relinquish_alloc_permit()
            tmem.free(tmem_ptr)
def test():
    """Run the QK -> BF16 pack -> PV (V=I) kernel once and check O against PyTorch.

    With V equal to the identity and no softmax arithmetic in the kernel, the
    expected output is ``(Q @ K^T).bfloat16()``. The check is a cosine
    similarity over the flattened output with a 0.99 threshold.

    Requires a CUDA device (all tensors are created with ``device='cuda'``).

    Raises:
        AssertionError: if the cosine similarity is below the threshold (or NaN),
            so that pytest and script runs both report the failure instead of
            only printing it.
    """
    torch.manual_seed(42)
    m, n, head_dim = 128, 128, 64
    q = torch.randn(m, head_dim, 1, dtype=torch.bfloat16, device='cuda')
    k = torch.randn(n, head_dim, 1, dtype=torch.bfloat16, device='cuda')
    # V = identity (128x128) presented MN-major: shape (128, 128), strides (1, 128),
    # i.e. the row index is the fast dimension.
    v = torch.eye(128, dtype=torch.bfloat16, device='cuda')
    v = v.as_strided((128, 128), (1, 128)).unsqueeze(-1)
    c = torch.zeros(m, n, 1, dtype=torch.bfloat16, device='cuda')
    qf = q[:,:,0].float(); kf = k[:,:,0].float()
    # With V=I and identity softmax: O = (Q@K^T).bf16() @ I = (Q@K^T).bf16()
    ref = (qf @ kf.T).bfloat16().float()
    mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
    mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
    mV = ct.from_dlpack(v).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v))
    mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
    stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
    kernel = TmemColOffsetKernel(mma_tiler_mn=(128, 128))
    print('Compiling...', flush=True)
    compiled = cute.compile(kernel, mQ, mK, mV, mC, stream)
    print('Running...', flush=True)
    compiled(mQ, mK, mV, mC, stream)
    torch.cuda.synchronize()
    out = c[:,:,0].float()
    cos = torch.nn.functional.cosine_similarity(out.flatten().unsqueeze(0), ref.flatten().unsqueeze(0)).item()
    threshold = 0.99
    # `>=` (rather than `not <`) so that a NaN cosine counts as a failure.
    passed = cos >= threshold
    print('TmemColOffset(128,128): cosine {:.6f} {}'.format(cos, 'PASS' if passed else 'FAIL'))
    # Explicit raise instead of a bare `assert` so the check survives `python -O`.
    if not passed:
        raise AssertionError(
            'TmemColOffset(128,128): cosine {:.6f} is below threshold {}'.format(cos, threshold))
# Script entry point: run the GPU test directly when this file is executed
# as a program rather than imported.
if __name__ == '__main__':
    test()