Files
nvfp4-megamoe-kernel/tests/archive/diag_tma_layout.py
biondizzle 524f0bdfb4 Clean up: archive diagnostics and superseded tests
Kept:
- example10 (CUTLASS LLM, O rescale + final normalize)
- example9 (SSA kv_coord version)
- working_softmax_maybe.py (working softmax snapshot from before the nuke)
- test_fmha_v3_stage_c.py (identity softmax baseline, n=128 cos 0.999998)
- test_fmha_v3.py (Stage A+B baseline)
- layertest.py, cudagraph_test.py (required)
- test_cutedsl.py, test_fp4_roundtrip.py (NVFP4 tests)

Archived: diag_tma_*, example8, test_diag_multitile, test_reference_fmha,
test_ref_minimal, test_tma_coord, test_fmha_v3_diag*, test_fmha_v3_12w,
test_dense_router, test_interleave*, test_fused_step1, test_router,
test_cache, test_compile_custom_op, test_custom_op, test_layer_schedule
2026-05-23 00:17:07 +00:00

117 lines
4.6 KiB
Python

"""Diagnostic: Check tBgK layout strides to see if the GMEM tile dim is degenerate."""
import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32
from cutlass.utils import LayoutEnum
import cutlass.torch as ct
import math
# ---- Problem sizes and random bf16 inputs on the GPU -----------------------
HEAD_DIM = 64
n = 256  # number of rows in K and V
_bf16_cuda = {'dtype': torch.bfloat16, 'device': 'cuda'}
q = torch.randn(128, HEAD_DIM, 1, **_bf16_cuda)
k = torch.randn(n, HEAD_DIM, 1, **_bf16_cuda)
v = torch.randn(n, HEAD_DIM, **_bf16_cuda)
# Append a trailing unit mode so V is rank-3 like q and k.
# NOTE(review): v_kernel is (n, HEAD_DIM, 1) while pv_mma below declares its
# B operand MN-major -- confirm this is the orientation the V TMA atom expects.
v_kernel = v.unsqueeze(-1)


def _to_dynamic_cute_tensor(t):
    """Wrap torch tensor *t* with ct.from_dlpack and mark its layout dynamic
    along the leading dim reported by ct.get_leading_dim."""
    return ct.from_dlpack(t).mark_layout_dynamic(leading_dim=ct.get_leading_dim(t))


mQ = _to_dynamic_cute_tensor(q)
mK = _to_dynamic_cute_tensor(k)
mV = _to_dynamic_cute_tensor(v_kernel)

# ---- Tiled MMAs: QK reads both operands from SMEM, PV reads A from TMEM ----
_Major = cute.nvgpu.OperandMajorMode
qk_mma = utils.sm100.make_trivial_tiled_mma(
    BFloat16,
    BFloat16,
    _Major.K,
    _Major.K,
    Float32,
    tcgen05.CtaGroup.ONE,
    (128, 128),
    tcgen05.OperandSource.SMEM,
)
pv_mma = utils.sm100.make_trivial_tiled_mma(
    BFloat16,
    BFloat16,
    _Major.K,
    _Major.MN,
    Float32,
    tcgen05.CtaGroup.ONE,
    (128, HEAD_DIM),
    tcgen05.OperandSource.TMEM,
)

# ---- MMA tilers: mode 2 is a multiple of each MMA's own mode-2 extent ------
qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
# NOTE(review): the factor 4 is hard-coded; presumably chosen so the tile's
# last extent equals HEAD_DIM -- confirm before changing HEAD_DIM.
qk_mma_tiler = (128, 128, qk_ik * 4)
pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
pv_mma_tiler = (128, HEAD_DIM, pv_ik * (128 // pv_ik))

# ---- 1x1x1 cluster layout divided by the QK MMA's thread-id shape ----------
_unit_cluster = cute.make_layout((1, 1, 1))
cluster_layout_vmnk = cute.tiled_divide(_unit_cluster, (qk_mma.thr_id.shape,))

kv_stage = 2
q_stage = 1

# ---- Staged SMEM layouts for the B operands, then a slice at index 0 of the
# ---- last mode (presumably the stage mode, since kv_stage is passed in) -----
k_smem_s = utils.sm100.make_smem_layout_b(qk_mma, qk_mma_tiler, BFloat16, kv_stage)
v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, pv_mma_tiler, BFloat16, kv_stage)
_first_stage = (None, None, None, 0)
k_s = cute.slice_(k_smem_s, _first_stage)
v_s = cute.slice_(v_smem_s, _first_stage)

# ---- TMA atoms (and the TMA-side tensors) for K and V ----------------------
_k_tma_op = utils.sm100.cluster_shape_to_tma_atom_B(
    cluster_layout_vmnk.shape, qk_mma.thr_id
)
tma_k, mK_tma = cute.nvgpu.make_tiled_tma_atom_B(
    _k_tma_op,
    mK,
    k_s,
    qk_mma_tiler,
    qk_mma,
    cluster_layout_vmnk.shape,
)
_v_tma_op = utils.sm100.cluster_shape_to_tma_atom_B(
    cluster_layout_vmnk.shape, pv_mma.thr_id
)
tma_v, mV_tma = cute.nvgpu.make_tiled_tma_atom_B(
    _v_tma_op,
    mV,
    v_s,
    pv_mma_tiler,
    pv_mma,
    cluster_layout_vmnk.shape,
)

# ---- Tile the TMA tensors with each tiler after slicing out its mode 0 -----
_drop_mode0 = (0, None, None)
_all_tiles = (None, None, None)
gK = cute.local_tile(mK_tma, cute.slice_(qk_mma_tiler, _drop_mode0), _all_tiles)
gV = cute.local_tile(mV_tma, cute.slice_(pv_mma_tiler, _drop_mode0), _all_tiles)
def _print_tile(label, tensor, per_mode_sizes=False):
    """Print shape, layout and stride of *tensor*, each line prefixed by *label*.

    When *per_mode_sizes* is true, an extra line lists cute.size of every
    top-level mode. A blank line is always printed last.
    """
    print(f'{label} shape: {cute.shape(tensor)}')
    print(f'{label} layout: {tensor.layout}')
    print(f'{label} stride: {tensor.layout.stride}')
    if per_mode_sizes:
        rank = len(cute.shape(tensor))
        sizes = [cute.size(tensor, mode=[axis]) for axis in range(rank)]
        print(f'{label} size per mode: {sizes}')
    print()


# Tiled GMEM views first; only gK gets the per-mode size line.
_print_tile('gK', gK, per_mode_sizes=True)
_print_tile('gV', gV)

# Partition the B operand of each MMA using slice 0 of the tiled MMA.
qk_thr = qk_mma.get_slice(0)
pv_thr = pv_mma.get_slice(0)
tCgK = qk_thr.partition_B(gK)
tCgV = pv_thr.partition_B(gV)

_print_tile('tCgK', tCgK)
_print_tile('tCgV', tCgV)
# SMEM-side tensors built from the single-stage K and V layouts.
sK = cute.make_tensor(BFloat16, k_s)
sV = cute.make_tensor(BFloat16, v_s)

# Layout built from the shape of the cluster layout sliced at (0, None, 0, 0).
_cluster_slice = cute.slice_(cluster_layout_vmnk, (0, None, 0, 0))
b_lay = cute.make_layout(_cluster_slice.shape)

# TMA partitioning for K and V: the first three modes of both the SMEM tensor
# and the MMA-partitioned GMEM tensor are grouped before partitioning.
tBsK, tBgK = cpasync.tma_partition(
    tma_k,
    0,
    b_lay,
    cute.group_modes(sK, 0, 3),
    cute.group_modes(tCgK, 0, 3),
)
tVsV, tVgV = cpasync.tma_partition(
    tma_v,
    0,
    b_lay,
    cute.group_modes(sV, 0, 3),
    cute.group_modes(tCgV, 0, 3),
)


def _print_tma_partition(title, part):
    """Print *title*, then shape, layout, stride and per-mode sizes of *part*,
    followed by a blank line."""
    print(title)
    print(f'shape: {cute.shape(part)}')
    print(f'layout: {part.layout}')
    print(f'stride: {part.layout.stride}')
    rank = len(cute.shape(part))
    sizes = [cute.size(part, mode=[axis]) for axis in range(rank)]
    print(f'size per mode: {sizes}')
    print()


_print_tma_partition('=== tBgK (K TMA partition) ===', tBgK)
_print_tma_partition('=== tVgV (V TMA partition) ===', tVgV)
# Now check the slices: index each TMA partition with two different
# coordinates and print what remains.
def _report_slice(name, tensor, coord):
    """Print a header, slice *tensor* at *coord*, then print the result.

    The header is printed before the slice is taken, so it still appears in
    the output if the slicing itself raises.

    Args:
        name: Label used in the header (e.g. 'tBgK').
        tensor: Tensor to index.
        coord: Tuple coordinate passed to ``tensor[...]``; ``None`` entries
            keep a mode, integer entries select an index in it.

    Returns:
        The sliced tensor, so the caller can keep it under a module-level name.
    """
    coord_text = ','.join(str(c) for c in coord)
    print(f'=== {name} after ({coord_text}) ===')
    sliced = tensor[coord]
    print(f'shape: {cute.shape(sliced)}')
    print(f'layout: {sliced.layout}')
    print(f'stride: {sliced.layout.stride}')
    return sliced


tBgK_nn = _report_slice('tBgK', tBgK, (None, None, 0, 0))
print()
tBgK_n0 = _report_slice('tBgK', tBgK, (None, 0, None, 0))
print()
tVgV_n0 = _report_slice('tVgV', tVgV, (None, 0, None, 0))
# The original omitted this separator, so the two tVgV sections ran together.
print()
tVgV_nn = _report_slice('tVgV', tVgV, (None, None, 0, 0))