Kept:
- example10 (CUTLASS LLM, O rescale + final normalize)
- example9 (SSA kv_coord version)
- working_softmax_maybe.py (working softmax snapshot from before the nuke)
- test_fmha_v3_stage_c.py (identity softmax baseline, n=128 cos 0.999998)
- test_fmha_v3.py (Stage A+B baseline)
- layertest.py, cudagraph_test.py (required)
- test_cutedsl.py, test_fp4_roundtrip.py (NVFP4 tests)

Archived: diag_tma_*, example8, test_diag_multitile, test_reference_fmha, test_ref_minimal, test_tma_coord, test_fmha_v3_diag*, test_fmha_v3_12w, test_dense_router, test_interleave*, test_fused_step1, test_router, test_cache, test_compile_custom_op, test_custom_op, test_layer_schedule
117 lines
4.6 KiB
Python
117 lines
4.6 KiB
Python
"""Diagnostic: Check tBgK layout strides to see if the GMEM tile dim is degenerate."""

import math

import torch, cutlass, cutlass.cute as cute, cutlass.utils as utils
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass import Float32, BFloat16, Int32
from cutlass.utils import LayoutEnum
import cutlass.torch as ct

# Problem size used by this diagnostic.
HEAD_DIM = 64
n = 256

# Random bf16 inputs on the CUDA device.  q and k are created with a trailing
# size-1 mode; v is created 2-D and v_kernel adds the same trailing mode as a
# view via unsqueeze(-1).
q = torch.randn(128, HEAD_DIM, 1, dtype=torch.bfloat16, device='cuda')
k = torch.randn(n, HEAD_DIM, 1, dtype=torch.bfloat16, device='cuda')
v = torch.randn(n, HEAD_DIM, dtype=torch.bfloat16, device='cuda')
v_kernel = v.unsqueeze(-1)

# Wrap the torch tensors as CuTe tensors via DLPack and mark their layouts
# dynamic, using the leading dim reported by cutlass.torch for each tensor.
# NOTE(review): q / mQ are not referenced by any later statement in this
# script; they are kept so the RNG draw order and module names stay the same.
mQ = ct.from_dlpack(q).mark_layout_dynamic(
    leading_dim=ct.get_leading_dim(q)
)
mK = ct.from_dlpack(k).mark_layout_dynamic(
    leading_dim=ct.get_leading_dim(k)
)
mV = ct.from_dlpack(v_kernel).mark_layout_dynamic(
    leading_dim=ct.get_leading_dim(v_kernel)
)

# Tiled MMAs for the two GEMMs, configured exactly as passed below:
#   qk_mma: operand major modes K / K, (128, 128) tile, operand source SMEM.
#   pv_mma: operand major modes K / MN, (128, HEAD_DIM) tile, operand source TMEM.
# NOTE(review): these CuTe DSL calls run at module scope rather than inside a
# @cute.jit function -- confirm the installed CUTLASS DSL supports that.
qk_mma = utils.sm100.make_trivial_tiled_mma(
    BFloat16, BFloat16,
    cute.nvgpu.OperandMajorMode.K, cute.nvgpu.OperandMajorMode.K,
    Float32, tcgen05.CtaGroup.ONE, (128, 128), tcgen05.OperandSource.SMEM,
)
pv_mma = utils.sm100.make_trivial_tiled_mma(
    BFloat16, BFloat16,
    cute.nvgpu.OperandMajorMode.K, cute.nvgpu.OperandMajorMode.MN,
    Float32, tcgen05.CtaGroup.ONE, (128, HEAD_DIM), tcgen05.OperandSource.TMEM,
)

# Size of mode 2 of each MMA's shape_mnk (the K extent, going by the attribute
# name), used to build the (M, N, K) tilers.
qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
# QK tiler: 128 x 128 with four MMA K-steps per tile.
qk_mma_tiler = (128, 128, qk_ik * 4)

pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
# PV tiler: K extent is the largest multiple of pv_ik that does not exceed 128.
pv_mma_tiler = (128, HEAD_DIM, pv_ik * (128 // pv_ik))

# A (1, 1, 1) cluster layout divided by the QK MMA's thread-id shape.
cluster_layout_vmnk = cute.tiled_divide(
    cute.make_layout((1, 1, 1)), (qk_mma.thr_id.shape,)
)

# Pipeline stage counts.
# NOTE(review): q_stage is not referenced by any later statement in this script.
kv_stage = 2
q_stage = 1

# Staged shared-memory layouts for the B operands of the two MMAs (K for QK,
# V for PV), each built with kv_stage stages.
k_smem_s = utils.sm100.make_smem_layout_b(qk_mma, qk_mma_tiler, BFloat16, kv_stage)
v_smem_s = utils.sm100.make_smem_layout_b(pv_mma, pv_mma_tiler, BFloat16, kv_stage)

# Keep the first three modes and fix the fourth at index 0.
# NOTE(review): the fourth mode is presumably the stage mode -- confirm.
k_s = cute.slice_(k_smem_s, (None, None, None, 0))
v_s = cute.slice_(v_smem_s, (None, None, None, 0))

# Build the TMA atom and the TMA-view tensor for K (B operand of the QK MMA).
# The first argument is chosen by the sm100 helper from the cluster shape and
# the MMA's thread-id layout.
tma_k, mK_tma = cute.nvgpu.make_tiled_tma_atom_B(
    utils.sm100.cluster_shape_to_tma_atom_B(cluster_layout_vmnk.shape, qk_mma.thr_id),
    mK, k_s, qk_mma_tiler, qk_mma, cluster_layout_vmnk.shape
)

# Same construction for V (B operand of the PV MMA).
tma_v, mV_tma = cute.nvgpu.make_tiled_tma_atom_B(
    utils.sm100.cluster_shape_to_tma_atom_B(cluster_layout_vmnk.shape, pv_mma.thr_id),
    mV, v_s, pv_mma_tiler, pv_mma, cluster_layout_vmnk.shape
)

# Tile the TMA-view tensors with the MMA tilers after slicing the tilers with
# coordinate (0, None, None), i.e. keeping tiler modes 1 and 2 only.  The
# all-None tile coordinate keeps every tile rather than selecting one.
gK = cute.local_tile(
    mK_tma, cute.slice_(qk_mma_tiler, (0, None, None)), (None, None, None)
)
gV = cute.local_tile(
    mV_tma, cute.slice_(pv_mma_tiler, (0, None, None)), (None, None, None)
)

# Dump the tiled global-memory views.  Only gK also reports a per-mode size.
print(f'gK shape: {cute.shape(gK)}')
print(f'gK layout: {gK.layout}')
print(f'gK stride: {gK.layout.stride}')
print(f'gK size per mode: {[cute.size(gK, mode=[i]) for i in range(len(cute.shape(gK)))]}')
print()
print(f'gV shape: {cute.shape(gV)}')
print(f'gV layout: {gV.layout}')
print(f'gV stride: {gV.layout.stride}')
print()

# Take slice 0 of each tiled MMA and partition the tiled global tensors as
# that slice's B operand.
qk_thr = qk_mma.get_slice(0)
pv_thr = pv_mma.get_slice(0)
tCgK = qk_thr.partition_B(gK)
tCgV = pv_thr.partition_B(gV)

# Dump shape / layout / stride of both MMA-partitioned tensors; each group is
# followed by a blank separator line.
for label, tensor in (('tCgK', tCgK), ('tCgV', tCgV)):
    print(f'{label} shape: {cute.shape(tensor)}')
    print(f'{label} layout: {tensor.layout}')
    print(f'{label} stride: {tensor.layout.stride}')
    print()

# Tensors built from the single-stage K / V layouts; they are passed to
# tma_partition together with the MMA-partitioned global tensors.
sK = cute.make_tensor(BFloat16, k_s)
sV = cute.make_tensor(BFloat16, v_s)

# Layout over mode 1 of cluster_layout_vmnk (all other modes fixed at 0).
b_lay = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape)

# TMA-partition K and V at coordinate 0 of b_lay.  Modes 0..2 of each tensor
# are grouped into a single mode first.  Each call returns a pair; the second
# element (tBgK / tVgV) is the tensor this diagnostic inspects.
tBsK, tBgK = cpasync.tma_partition(
    tma_k, 0, b_lay, cute.group_modes(sK, 0, 3), cute.group_modes(tCgK, 0, 3)
)
tVsV, tVgV = cpasync.tma_partition(
    tma_v, 0, b_lay, cute.group_modes(sV, 0, 3), cute.group_modes(tCgV, 0, 3)
)

# Dump both TMA partitions, including the size of every top-level mode so a
# collapsed tile mode is easy to spot.
for title, part in (
    ('tBgK (K TMA partition)', tBgK),
    ('tVgV (V TMA partition)', tVgV),
):
    print(f'=== {title} ===')
    print(f'shape: {cute.shape(part)}')
    print(f'layout: {part.layout}')
    print(f'stride: {part.layout.stride}')
    print(f'size per mode: {[cute.size(part, mode=[i]) for i in range(len(cute.shape(part)))]}')
    print()

# Now check the slices: index each TMA partition with the two coordinate
# patterns of interest and dump what is left.
slice_checks = (
    ('tBgK', tBgK, (None, None, 0, 0)),
    ('tBgK', tBgK, (None, 0, None, 0)),
    ('tVgV', tVgV, (None, 0, None, 0)),
    ('tVgV', tVgV, (None, None, 0, 0)),
)
for idx, (label, part, coord) in enumerate(slice_checks):
    # Blank separator between sections; the original omitted it before the
    # final tVgV section, which ran two reports together.
    if idx:
        print()
    # Render the coordinate without spaces, e.g. "(None,None,0,0)".
    coord_text = '(' + ','.join(str(c) for c in coord) + ')'
    # The header is printed before slicing so a failing slice is identifiable.
    print(f'=== {label} after {coord_text} ===')
    sliced = part[coord]
    print(f'shape: {cute.shape(sliced)}')
    print(f'layout: {sliced.layout}')
    print(f'stride: {sliced.layout.stride}')