Files
nvfp4-megamoe-kernel/tests/archive/test_diag_smem_layout.py
biondizzle 9cbdc92744 Restructure: cutedsl/ -> dsv4/ with proper layering
- Split bridge.py -> ops/quantize.py, ops/layouts.py, ops/gemm_runner.py
- Renamed classes: CuTeDSLNvfp4Linear -> Nvfp4Linear, etc.
- Moved kernel code to dsv4/kernels/ (gemm, attention, compressor, decode, cuda)
- Moved PyTorch bridges to dsv4/ops/
- Moved nn.Module layers to dsv4/layers/
- Moved reference implementations to dsv4/reference/
- Moved vendored CUTLASS code to vendored/
- Archived ~190 debug tests to tests/archive/
- Kept ~15 canonical tests in tests/unit/
- Updated all import paths
- Added stubs for future components (model/, cache/, loader/)
- Updated pyproject.toml: dsv4-inference package name
2026-05-21 17:30:44 +00:00

73 lines
3.6 KiB
Python

"""Print V SMEM layouts for (128,64) and (128,128) PV. Must run inside JIT."""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils, cutlass.pipeline as pipeline
from cutlass.cute.nvgpu import cpasync, tcgen05, OperandMajorMode
from cutlass import Float32, BFloat16, Int32, Boolean, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cuda.bindings.driver as cuda
import cutlass.torch as ct
class SmemLayoutKernel:
    """Diagnostic callable that prints B-operand SMEM layouts for QK and PV MMAs.

    ``__call__`` builds three tiled MMAs through
    ``utils.sm100.make_trivial_tiled_mma`` (one QK, two PV with N=64 and
    N=128), derives a layout for each via ``utils.sm100.make_smem_layout_b``
    and prints that layout's ``outer`` and ``inner`` attributes.
    """

    def __init__(self):
        """Populate the configuration attributes with their fixed defaults.

        None of these attributes is read by ``__call__``; ``q_dtype`` and
        ``o_dtype`` are overwritten there from the incoming tensors.
        """
        # Element types.
        self.q_dtype = self.o_dtype = self.c_dtype = BFloat16
        self.qk_acc_dtype = Float32

        # CTA / cluster configuration.
        self.use_2cta_instrs = False
        self.cluster_shape_mn = (1, 1)
        self.cta_group = tcgen05.CtaGroup.ONE

        # Warp role ids and launch sizing.
        self.epilogue_warp_id = (0, 1, 2, 3)
        self.mma_warp_id = 4
        self.tma_warp_id = 5
        self.threads_per_cta = 192
        self.num_c_stage = 2

    @cute.jit
    def __call__(self, q, k, v, c, stream):
        """Print the QK and PV B-operand SMEM layouts and the PV MMA shapes.

        Args:
            q: tensor whose layout supplies the A-operand major mode of the
                QK MMA; its element type is stored in ``self.q_dtype``.
            k: tensor whose layout supplies the B-operand major mode of the
                QK MMA.
            v: tensor whose layout supplies the B-operand major mode of both
                PV MMAs.
            c: tensor whose element type is stored in ``self.o_dtype``.
            stream: accepted but not used in this body.
        """
        self.q_dtype = q.element_type
        self.o_dtype = c.element_type

        q_major_mode = LayoutEnum.from_tensor(q).mma_major_mode()
        k_major_mode = LayoutEnum.from_tensor(k).mma_major_mode()
        v_major_mode = LayoutEnum.from_tensor(v).mma_major_mode()
        # Evaluated exactly as before even though the result is not consumed.
        out_layout = LayoutEnum.from_tensor(c)

        # ---- QK: both operands sourced from SMEM, 128x128 tile ----
        qk_tiled_mma = utils.sm100.make_trivial_tiled_mma(
            BFloat16,
            BFloat16,
            q_major_mode,
            k_major_mode,
            Float32,
            tcgen05.CtaGroup.ONE,
            (128, 128),
            tcgen05.OperandSource.SMEM,
        )
        inst_k = cute.size(qk_tiled_mma.shape_mnk, mode=[2])
        # The tiler's K extent is four times the MMA's K size.
        qk_tiler = (128, 128, inst_k * 4)
        # NOTE(review): the trailing 1 is presumably the stage count -- confirm.
        k_smem = utils.sm100.make_smem_layout_b(qk_tiled_mma, qk_tiler, BFloat16, 1)
        print(f"QK B SMEM: outer={k_smem.outer}, inner={k_smem.inner}")

        # ---- PV with N=64: A operand sourced from TMEM, K-major ----
        pv64_mma = utils.sm100.make_trivial_tiled_mma(
            BFloat16,
            BFloat16,
            OperandMajorMode.K,
            v_major_mode,
            Float32,
            tcgen05.CtaGroup.ONE,
            (128, 64),
            tcgen05.OperandSource.TMEM,
        )
        pv64_tiler = (128, 64, 128)
        v_smem_n64 = utils.sm100.make_smem_layout_b(pv64_mma, pv64_tiler, BFloat16, 1)
        print(f"PV(128,64) V SMEM: outer={v_smem_n64.outer}, inner={v_smem_n64.inner}")

        # ---- PV with N=128: same construction, wider N ----
        pv128_mma = utils.sm100.make_trivial_tiled_mma(
            BFloat16,
            BFloat16,
            OperandMajorMode.K,
            v_major_mode,
            Float32,
            tcgen05.CtaGroup.ONE,
            (128, 128),
            tcgen05.OperandSource.TMEM,
        )
        pv128_tiler = (128, 128, 128)
        v_smem_n128 = utils.sm100.make_smem_layout_b(pv128_mma, pv128_tiler, BFloat16, 1)
        print(f"PV(128,128) V SMEM: outer={v_smem_n128.outer}, inner={v_smem_n128.inner}")

        # Finally report the shape_mnk of both PV MMAs.
        print(f"PV(128,64) MMA shape_mnk={pv64_mma.shape_mnk}")
        print(f"PV(128,128) MMA shape_mnk={pv128_mma.shape_mnk}")
# Driver: build small bf16 CUDA tensors and JIT-compile SmemLayoutKernel on
# them. The compiled object is only stored here, not invoked.
torch.manual_seed(42)
m, n, head_dim = 128, 128, 64

# Keyword arguments shared by every tensor allocated below.
_BF16_CUDA = {"dtype": torch.bfloat16, "device": 'cuda'}


def _as_dynamic_cute(tensor):
    """Wrap a torch tensor via DLPack and mark its layout dynamic along its leading dim."""
    return ct.from_dlpack(tensor).mark_layout_dynamic(
        leading_dim=ct.get_leading_dim(tensor)
    )


# Q is drawn before K so the random stream is consumed in the same order.
q = torch.randn(m, head_dim, 1, **_BF16_CUDA)
k = torch.randn(n, head_dim, 1, **_BF16_CUDA)

# V backing buffer: all zeros apart from a single 1.0 at [0, 0].
v_data = torch.zeros(head_dim, n, **_BF16_CUDA)
v_data[0, 0] = 1.0
# Re-view the buffer as (head_dim, n) with strides (1, head_dim), so the first
# axis is the unit-stride one, then append a trailing size-1 axis.
v = v_data.as_strided(size=(head_dim, n), stride=(1, head_dim))
v = v.unsqueeze(-1)

c = torch.zeros(m, head_dim, 1, **_BF16_CUDA)

# Convert in the order q, k, v, c.
mQ, mK, mV, mC = (_as_dynamic_cute(t) for t in (q, k, v, c))

stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
kernel = SmemLayoutKernel()
# NOTE(review): the layout prints in __call__ are expected to fire while
# cute.compile traces the callable -- confirm against the CuTe DSL docs.
compiled = cute.compile(kernel, mQ, mK, mV, mC, stream)