Files
nvfp4-megamoe-kernel/tests/test_diag_smem_layout.py
2026-05-21 05:08:57 +00:00

73 lines
3.6 KiB
Python

"""Print V SMEM layouts for (128,64) and (128,128) PV. Must run inside JIT."""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05, OperandMajorMode
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
import cutlass.torch as ct
class SmemLayoutKernel:
    """Diagnostic callable that prints B-operand SMEM layouts for QK and PV MMAs.

    Nothing is launched by the visible code: ``__call__`` only builds tiled
    MMAs and their shared-memory layouts and prints them, so the output
    appears while the function is being traced/compiled by ``cute.jit``.
    """

    def __init__(self):
        """Store default dtypes and CTA/warp configuration on the instance.

        NOTE(review): apart from ``q_dtype`` / ``o_dtype`` (overwritten in
        ``__call__``), none of these attributes is read by the code visible
        in this class -- presumably mirrored from a full attention kernel.
        """
        # Element types.
        self.q_dtype = BFloat16
        self.o_dtype = BFloat16
        self.c_dtype = BFloat16
        self.qk_acc_dtype = Float32
        # CTA / cluster configuration.
        self.use_2cta_instrs = False
        self.cluster_shape_mn = (1, 1)
        self.cta_group = tcgen05.CtaGroup.ONE
        # Warp role assignment.
        self.epilogue_warp_id = (0, 1, 2, 3)
        self.mma_warp_id = 4
        self.tma_warp_id = 5
        self.threads_per_cta = 192
        self.num_c_stage = 2

    @cute.jit
    def __call__(self, q, k, v, c, stream):
        """Print SMEM layouts of the B operand for one QK and two PV tilings.

        Args:
            q: Tensor whose element type and layout give the A-operand major
                mode of the QK MMA.
            k: Tensor whose layout gives the B-operand major mode of the QK MMA.
            v: Tensor whose layout gives the B-operand major mode of both PV MMAs.
            c: Output tensor; only its element type and layout enum are queried.
            stream: Not referenced by this body.
        """
        self.q_dtype = q.element_type
        self.o_dtype = c.element_type

        # Major modes are derived from the layouts of the incoming tensors.
        q_major_mode = LayoutEnum.from_tensor(q).mma_major_mode()
        k_major_mode = LayoutEnum.from_tensor(k).mma_major_mode()
        v_major_mode = LayoutEnum.from_tensor(v).mma_major_mode()
        # Result is unused below; the call is kept so tracing is unchanged.
        c_layout = LayoutEnum.from_tensor(c)

        # ---- QK: both operands sourced from SMEM, 128x128 MMA tile ----
        qk_tiled_mma = utils.sm100.make_trivial_tiled_mma(
            BFloat16,
            BFloat16,
            q_major_mode,
            k_major_mode,
            Float32,
            tcgen05.CtaGroup.ONE,
            (128, 128),
            tcgen05.OperandSource.SMEM,
        )
        # K extent of the tiler is four times the MMA's K mode size.
        k_per_inst = cute.size(qk_tiled_mma.shape_mnk, mode=[2])
        qk_tile_mnk = (128, 128, k_per_inst * 4)
        qk_b_layout = utils.sm100.make_smem_layout_b(
            qk_tiled_mma, qk_tile_mnk, BFloat16, 1
        )
        print(f"QK B SMEM: outer={qk_b_layout.outer}, inner={qk_b_layout.inner}")

        # ---- PV with a (128, 64) MMA tile: A operand sourced from TMEM ----
        pv64_tiled_mma = utils.sm100.make_trivial_tiled_mma(
            BFloat16,
            BFloat16,
            OperandMajorMode.K,
            v_major_mode,
            Float32,
            tcgen05.CtaGroup.ONE,
            (128, 64),
            tcgen05.OperandSource.TMEM,
        )
        pv64_tile_mnk = (128, 64, 128)
        pv64_v_layout = utils.sm100.make_smem_layout_b(
            pv64_tiled_mma, pv64_tile_mnk, BFloat16, 1
        )
        print(f"PV(128,64) V SMEM: outer={pv64_v_layout.outer}, inner={pv64_v_layout.inner}")

        # ---- PV with a (128, 128) MMA tile: A operand sourced from TMEM ----
        pv128_tiled_mma = utils.sm100.make_trivial_tiled_mma(
            BFloat16,
            BFloat16,
            OperandMajorMode.K,
            v_major_mode,
            Float32,
            tcgen05.CtaGroup.ONE,
            (128, 128),
            tcgen05.OperandSource.TMEM,
        )
        pv128_tile_mnk = (128, 128, 128)
        pv128_v_layout = utils.sm100.make_smem_layout_b(
            pv128_tiled_mma, pv128_tile_mnk, BFloat16, 1
        )
        print(f"PV(128,128) V SMEM: outer={pv128_v_layout.outer}, inner={pv128_v_layout.inner}")

        # Also print the MNK shapes of the two PV tiled MMAs.
        print(f"PV(128,64) MMA shape_mnk={pv64_tiled_mma.shape_mnk}")
        print(f"PV(128,128) MMA shape_mnk={pv128_tiled_mma.shape_mnk}")
# Driver: build small bf16 CUDA tensors and compile the diagnostic kernel.
# The layouts are printed by SmemLayoutKernel.__call__.
torch.manual_seed(42)
m, n, head_dim = 128, 128, 64
_tensor_opts = dict(dtype=torch.bfloat16, device='cuda')

# q is drawn before k so the seeded values match the RNG call order.
q = torch.randn(m, head_dim, 1, **_tensor_opts)
k = torch.randn(n, head_dim, 1, **_tensor_opts)

# V: zero buffer with a single 1.0 at [0, 0], re-viewed with strides
# (1, head_dim) so the first dimension is the unit-stride one, then given a
# trailing size-1 dimension like q/k/c.
v_data = torch.zeros(head_dim, n, **_tensor_opts)
v_data[0, 0] = 1.0
v = torch.as_strided(v_data, (head_dim, n), (1, head_dim)).unsqueeze(-1)

c = torch.zeros(m, head_dim, 1, **_tensor_opts)

# Wrap each torch tensor for CuTe via DLPack, marking the layout dynamic
# along the tensor's leading dimension.
mQ, mK, mV, mC = (
    ct.from_dlpack(t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(t))
    for t in (q, k, v, c)
)

stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = SmemLayoutKernel()
compiled = cute.compile(kernel, mQ, mK, mV, mC, stream)