D1.3: Layout diagnostic v2 - run inside JIT-compiled kernel

This commit is contained in:
2026-05-23 22:16:57 +00:00
parent ec8fd1474c
commit d3d0020b4e

View File

@@ -1,241 +1,182 @@
"""
D1.3 SMEM-P diagnostic: print ALL relevant shapes and layouts.
Goal: Understand the exact QK C-fragment partition (source) and
PV A-operand SMEM layout (destination) so we can write a proper
register→SMEM copy for P values.
This test ONLY prints shapes. No kernel execution. Pure layout analysis.
Uses FmhaKernel._setup to compute layouts the same way the kernel does.
Runs layout analysis inside a JIT-compiled function to get MLIR context.
"""
import torch, math
import torch, math, sys
import cutlass, cutlass.cute as cute, cutlass.utils as utils
from cutlass.cute.nvgpu import tcgen05, cpasync
from cutlass import Float32, BFloat16, Int32
from cutlass import Float32, BFloat16, Int32, const_expr
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cutlass.torch as ct
def analyze_layouts(hd, s_k=128):
print(f"\n{'='*60}")
print(f" HEAD_DIM={hd}, s_k={s_k}")
print(f"{'='*60}")
def main():
for hd in [64, 256]:
print(f"\n{'='*60}")
print(f" HEAD_DIM={hd}")
print(f"{'='*60}")
m = 128
use_smem_p = hd > 64
pv_n_tile = min(hd, 256)
n_kv_tiles = s_k // 128
m = 128; s_k = 128
use_smem_p = hd > 64
pv_n_tile = min(hd, 256)
# Create dummy tensors just for layout computation
q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
k = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda')
c = torch.randn(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
k = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda')
v = torch.randn(s_k, pv_n_tile, dtype=torch.bfloat16, device='cuda')
c = torch.zeros(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
a_major = LayoutEnum.from_tensor(ct.from_dlpack(q)).mma_major_mode()
b_major = LayoutEnum.from_tensor(ct.from_dlpack(k)).mma_major_mode()
c_major = LayoutEnum.from_tensor(ct.from_dlpack(c)).mma_major_mode()
# Use FmhaKernel's actual __call__ path to compute layouts
from dsv4.kernels.attention.fmha import FmhaKernel
kern = FmhaKernel(head_dim=hd, s_k=s_k)
# Build QK MMA
qk_mma = utils.sm100.make_trivial_tiled_mma(
BFloat16, BFloat16, a_major, b_major, Float32,
tcgen05.CtaGroup.ONE, (128, 128), tcgen05.OperandSource.SMEM,
)
# Trigger _setup by calling compile with a dummy kernel
# Actually, we can access the _setup by calling __call__ on the kernel
# But that runs the full kernel. Let me manually call _setup by
# building the MMAs the same way __call__ does.
# Build PV MMA (SMEM-P for hd>64, TMEM-P for hd<=64)
pv_a_major = a_major if use_smem_p else cute.nvgpu.OperandMajorMode.K
pv_source = tcgen05.OperandSource.SMEM if use_smem_p else tcgen05.OperandSource.TMEM
pv_mma = utils.sm100.make_trivial_tiled_mma(
BFloat16, BFloat16, pv_a_major, c_major, Float32,
tcgen05.CtaGroup.ONE, (128, pv_n_tile), pv_source,
)
a_major = LayoutEnum.from_tensor(ct.from_dlpack(q)).mma_major_mode()
b_major = LayoutEnum.from_tensor(ct.from_dlpack(k)).mma_major_mode()
# Tiler computation (same as FmhaKernel._setup)
qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
qk_mma_tiler = (128, 128, qk_ik * 4)
pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
pv_mma_tiler = (128, pv_n_tile, pv_ik * (128 // pv_ik))
# Create a minimal analysis kernel that just prints shapes
@cute.kernel
def _diag_kernel(_q, _k, _v, _c):
# Build QK MMA
qk_mma = utils.sm100.make_trivial_tiled_mma(
BFloat16, BFloat16, a_major, b_major, Float32,
tcgen05.CtaGroup.ONE, (128, 128), tcgen05.OperandSource.SMEM,
)
print(f"\n--- MMA info ---")
print(f" QK MMA shape_mnk: {qk_mma.shape_mnk}")
print(f" PV MMA shape_mnk: {pv_mma.shape_mnk}")
print(f" QK tiler: {qk_mma_tiler}")
print(f" PV tiler: {pv_mma_tiler}")
print(f" use_smem_p: {use_smem_p}")
print(f" PV A source: {pv_source}")
print(f" PV A major: {pv_a_major}")
pv_a_major = a_major if use_smem_p else cute.nvgpu.OperandMajorMode.K
pv_source = tcgen05.OperandSource.SMEM if use_smem_p else tcgen05.OperandSource.TMEM
c_major = LayoutEnum.from_tensor(ct.from_dlpack(_c)).mma_major_mode()
# QK C-fragment (where S lives in TMEM, where softmax reads P from)
qk_thr = qk_mma.get_slice(0)
qk_as = qk_thr.partition_shape_C(qk_mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_as)
print(f"\n--- QK C-fragment (tStS = S and P in TMEM) ---")
print(f" qk_as (partition shape): {qk_as}")
print(f" tStS shape: {cute.shape(tStS)}")
print(f" tStS size: {cute.size(tStS)}")
print(f" tStS layout: {tStS.layout}")
pv_mma = utils.sm100.make_trivial_tiled_mma(
BFloat16, BFloat16, pv_a_major, c_major, Float32,
tcgen05.CtaGroup.ONE, (128, pv_n_tile), pv_source,
)
# PV C-fragment (where O lives in TMEM)
pv_thr = pv_mma.get_slice(0)
pv_as = pv_thr.partition_shape_C(pv_mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_as)
print(f"\n--- PV C-fragment (tOtO = O in TMEM) ---")
print(f" pv_as (partition shape): {pv_as}")
print(f" tOtO shape: {cute.shape(tOtO)}")
print(f" tOtO size: {cute.size(tOtO)}")
print(f" tOtO layout: {tOtO.layout}")
qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
qk_mma_tiler = (128, 128, qk_ik * 4)
pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
pv_mma_tiler = (128, pv_n_tile, pv_ik * (128 // pv_ik))
# PV A-operand SMEM layout (where P must be written for SMEM-P)
p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1)
print(f"\n--- PV A-operand SMEM layout (sP) ---")
print(f" p_smem_s outer: {p_smem_s.outer}")
print(f" p_smem_s inner (swizzle): {p_smem_s.inner}")
outer_layout = cute.make_layout(p_smem_s.outer)
print(f" p_smem_s outer shape: {cute.shape(outer_layout)}")
print(f" p_smem_s outer size: {cute.size(outer_layout)}")
print(f" QK MMA shape_mnk: {qk_mma.shape_mnk}")
print(f" PV MMA shape_mnk: {pv_mma.shape_mnk}")
print(f" QK tiler: {qk_mma_tiler}")
print(f" PV tiler: {pv_mma_tiler}")
print(f" use_smem_p: {use_smem_p}")
# PV A-operand fragments (what MMA reads from TMEM or SMEM)
# For TMEM-P: tOrP reads from TMEM
# For SMEM-P: tCrP reads from SMEM, tOrP also reads from SMEM
p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1)
tP_tmem = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
tP_smem = cute.make_tensor(cute.find(cute.make_layout(p_smem_s.outer), 0), p_smem_s.outer)
# QK C-fragment
qk_thr = qk_mma.get_slice(0)
qk_as = qk_thr.partition_shape_C(qk_mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_as)
print(f" tStS shape: {cute.shape(tStS)} layout: {tStS.layout}")
tOrP_tmem = pv_thr.make_fragment_A(tP_tmem)
tCrP_smem = pv_mma.make_fragment_A(tP_smem)
tOrP_smem = pv_thr.make_fragment_A(tP_smem)
# PV C-fragment
pv_thr = pv_mma.get_slice(0)
pv_as = pv_thr.partition_shape_C(pv_mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_as)
print(f" tOtO shape: {cute.shape(tOtO)} layout: {tOtO.layout}")
print(f"\n--- PV A-operand fragments ---")
print(f" tOrP_tmem (TMEM-P A-read): shape={cute.shape(tOrP_tmem)}")
print(f" tCrP_smem (SMEM-P A-load): shape={cute.shape(tCrP_smem)} layout={tCrP_smem.layout}")
print(f" tOrP_smem (SMEM-P A-read): shape={cute.shape(tOrP_smem)} layout={tOrP_smem.layout}")
print(f" tOrP_tmem size: {cute.size(tOrP_tmem)}")
print(f" tCrP_smem size: {cute.size(tCrP_smem)}")
# PV A-operand SMEM layout (sP)
p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1)
sP_alloc = cute.make_tensor(cute.find(cute.make_layout(p_smem_s.outer), 0), p_smem_s.outer)
print(f" sP shape: {cute.shape(sP_alloc)}")
sP_2d = cute.group_modes(sP_alloc, 0, 3)
print(f" sP_2d shape: {cute.shape(sP_2d)} rank={len(cute.shape(sP_2d))}")
print(f" sP_2d layout: {sP_2d.layout}")
# Slice tOrP the way the kernel does: [(None,None,None,0)]
tOrP_base_tmem = pv_thr.make_fragment_A(tP_tmem)
tOrP_tmem_sliced = tOrP_base_tmem[(None,None,None,0)]
tOrP_base_smem = pv_thr.make_fragment_A(tP_smem)
tOrP_smem_sliced = tOrP_base_smem[(None,None,None,0)]
print(f"\n--- PV A-operand after slice [(None,None,None,0)] ---")
print(f" tOrP_tmem_sliced: shape={cute.shape(tOrP_tmem_sliced)}")
print(f" tOrP_smem_sliced: shape={cute.shape(tOrP_smem_sliced)}")
# PV A-operand fragments
tP_tmem = cute.make_tensor(tStS.iterator, p_smem_s.outer) # dummy tmem ptr
tP_smem = cute.make_tensor(cute.find(cute.make_layout(p_smem_s.outer), 0), p_smem_s.outer)
tOrP_tmem = pv_thr.make_fragment_A(tP_tmem)
tCrP_smem = pv_mma.make_fragment_A(tP_smem)
tOrP_smem = pv_thr.make_fragment_A(tP_smem)
print(f" tOrP_tmem shape: {cute.shape(tOrP_tmem)}")
print(f" tCrP_smem shape: {cute.shape(tCrP_smem)} layout: {tCrP_smem.layout}")
print(f" tOrP_smem shape: {cute.shape(tOrP_smem)}")
# Softmax thread partition of S (via TMEM load atoms)
sfw_idx = 0
tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32)
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS)
thr_load = tiled_tmem_load.get_slice(sfw_idx)
tTMEM_LOADtS = thr_load.partition_S(tStS)
cS = cute.make_identity_tensor((128, 128))
tScS = qk_thr.partition_C(cS)
tTMEM_LOADcS = thr_load.partition_D(tScS)
# Slice as kernel does
tOrP_tmem_s = tOrP_tmem[(None,None,None,0)]
tOrP_smem_s = tOrP_smem[(None,None,None,0)]
print(f" tOrP_tmem sliced: {cute.shape(tOrP_tmem_s)}")
print(f" tOrP_smem sliced: {cute.shape(tOrP_smem_s)}")
print(f"\n--- Softmax thread 0: TMEM load partition ---")
print(f" tTMEM_LOADtS shape: {cute.shape(tTMEM_LOADtS)}")
print(f" tTMEM_LOADcS shape: {cute.shape(tTMEM_LOADcS)}")
print(f" tTMEM_LOADcS layout: {tTMEM_LOADcS.layout}")
print(f" => Each softmax thread has {cute.size(tTMEM_LOADcS)} S/P coordinates")
# Softmax TMEM load partition
sfw_idx = 0
tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32)
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS)
thr_load = tiled_tmem_load.get_slice(sfw_idx)
tTMEM_LOADtS = thr_load.partition_S(tStS)
cS = cute.make_identity_tensor((128, 128))
tScS = qk_thr.partition_C(cS)
tTMEM_LOADcS = thr_load.partition_D(tScS)
print(f" tTMEM_LOADcS (coord) shape: {cute.shape(tTMEM_LOADcS)} layout: {tTMEM_LOADcS.layout}")
# P register bridge: what P values does each softmax thread have?
# After softmax, P is in rP_bf16 which shares layout with rS
# rP_bf16 = make_tensor(recast_ptr(rP_words.iterator, dtype=BF16), tTMEM_LOADrS.layout)
# where tTMEM_LOADrS = make_rmem_tensor(tTMEM_LOADcS.shape, FP32)
p_cols_fp32 = pv_mma_tiler[2] * BFloat16.width // Float32.width
tStP_layout = cute.composition(tStS.layout, cute.make_layout((pv_mma_tiler[0], p_cols_fp32)))
print(f"\n--- P columns in TMEM ---")
print(f" p_cols_fp32 (PV tiler[2] * BF16/FP32): {p_cols_fp32}")
print(f" P TMEM sub-layout (composition of tStS): {tStP_layout}")
print(f" P TMEM sub-layout shape: {cute.shape(cute.make_layout(tStP_layout))}")
# P sub-layout in TMEM
p_cols_fp32 = pv_mma_tiler[2] * BFloat16.width // Float32.width
print(f" p_cols_fp32: {p_cols_fp32}")
# Now the critical question: what's the sP shape that the softmax needs to write to?
sP_alloc = cute.make_tensor(cute.find(cute.make_layout(p_smem_s.outer), 0), p_smem_s.outer)
sP_2d = cute.group_modes(sP_alloc, 0, 3)
print(f"\n--- sP for SMEM-P write target ---")
print(f" sP shape (3D): {cute.shape(sP_alloc)}")
print(f" sP_2d shape (grouped): {cute.shape(sP_2d)}")
print(f" sP_2d rank: {len(cute.shape(sP_2d))}")
# make_tiled_copy_C with qk_mma
print(f"\n --- make_tiled_copy_C(qk_mma) ---")
try:
smem_copy_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), BFloat16, num_bits_per_copy=128)
tiled_copy = cute.make_tiled_copy_C(smem_copy_atom, qk_mma)
thr_copy = tiled_copy.get_slice(sfw_idx)
print(f" OK, created tiled_copy")
try:
tD = thr_copy.partition_D(sP_2d)
print(f" partition_D(sP_2d) shape: {cute.shape(tD)} rank={len(cute.shape(tD))}")
except Exception as e:
print(f" partition_D(sP_2d) FAILED: {e}")
# Try make_tiled_copy_C with QK MMA (the approach that gave rank mismatch)
print(f"\n--- make_tiled_copy_C(smem_copy_atom, qk_mma) ---")
try:
smem_copy_atom = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(),
BFloat16,
num_bits_per_copy=128,
)
tiled_copy_qk = cute.make_tiled_copy_C(smem_copy_atom, qk_mma)
thr_qk = tiled_copy_qk.get_slice(sfw_idx)
tSMEM_QKsP = thr_qk.partition_D(sP_2d)
print(f" partition_D(sP_2d) shape: {cute.shape(tSMEM_QKsP)} rank={len(cute.shape(tSMEM_QKsP))}")
# Try source partition with coord layout
try:
rP_coord = cute.make_tensor(cute.find(tTMEM_LOADcS.layout, 0), tTMEM_LOADcS.layout)
tS = thr_copy.partition_S(rP_coord)
print(f" partition_S(coord) shape: {cute.shape(tS)} rank={len(cute.shape(tS))}")
except Exception as e:
print(f" partition_S(coord) FAILED: {e}")
# Try partition_S with various source layouts
# 1. Identity tensor (coordinate) layout
try:
rP_coord = cute.make_tensor(cute.find(tTMEM_LOADcS.layout, 0), tTMEM_LOADcS.layout)
tSMEM_QKrP = thr_qk.partition_S(rP_coord)
print(f" partition_S(coord layout) shape: {cute.shape(tSMEM_QKrP)} rank={len(cute.shape(tSMEM_QKrP))}")
except Exception as e:
print(f" partition_S(coord layout) FAILED: {e}")
# Try source with QK C-frag layout
try:
rP_qk = cute.make_tensor(cute.find(tStS.layout, 0), tStS.layout)
tS = thr_copy.partition_S(rS=rP_qk)
print(f" partition_S(QK C-frag) shape: {cute.shape(tS)} rank={len(cute.shape(tS))}")
except Exception as e:
print(f" partition_S(QK C-frag) FAILED: {e}")
except Exception as e:
print(f" make_tiled_copy_C FAILED: {e}")
# 2. QK C-fragment layout (tStS)
try:
rP_qk = cute.make_tensor(cute.find(tStS.layout, 0), tStS.layout)
tSMEM_QKrP = thr_qk.partition_S(rP_qk)
print(f" partition_S(QK C-frag layout) shape: {cute.shape(tSMEM_QKrP)} rank={len(cute.shape(tSMEM_QKrP))}")
except Exception as e:
print(f" partition_S(QK C-frag layout) FAILED: {e}")
# make_tiled_copy_C with pv_mma
print(f"\n --- make_tiled_copy_C(pv_mma) ---")
try:
tiled_copy_pv = cute.make_tiled_copy_C(smem_copy_atom, pv_mma)
thr_copy_pv = tiled_copy_pv.get_slice(sfw_idx)
print(f" OK, created tiled_copy_pv")
try:
tD = thr_copy_pv.partition_D(sP_2d)
print(f" partition_D(sP_2d) shape: {cute.shape(tD)} rank={len(cute.shape(tD))}")
except Exception as e:
print(f" partition_D(sP_2d) FAILED: {e}")
except Exception as e:
print(f" make_tiled_copy_C(pv_mma) FAILED: {e}")
# 3. The rmem layout that softmax actually has P in
rP_rmem_layout = cute.make_layout(tTMEM_LOADcS.shape)
try:
rP_rmem = cute.make_tensor(cute.find(rP_rmem_layout, 0), rP_rmem_layout)
tSMEM_QKrP = thr_qk.partition_S(rP_rmem)
print(f" partition_S(rmem layout) shape: {cute.shape(tSMEM_QKrP)} rank={len(cute.shape(tSMEM_QKrP))}")
except Exception as e:
print(f" partition_S(rmem layout) FAILED: {e}")
# Compile and run
import cuda.bindings.driver as cuda
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
except Exception as e:
print(f" make_tiled_copy_C(qk_mma) FAILED: {e}")
mQ = ct.from_dlpack(q).mark_layout_dynamic(leading_dim=ct.get_leading_dim(q))
mK = ct.from_dlpack(k).mark_layout_dynamic(leading_dim=ct.get_leading_dim(k))
mV = ct.from_dlpack(v.unsqueeze(-1)).mark_layout_dynamic(leading_dim=ct.get_leading_dim(v.unsqueeze(-1)))
mC = ct.from_dlpack(c).mark_layout_dynamic(leading_dim=ct.get_leading_dim(c))
# Try make_tiled_copy_C with PV MMA (different approach)
print(f"\n--- make_tiled_copy_C(smem_copy_atom, pv_mma) ---")
try:
smem_copy_atom = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(),
BFloat16,
num_bits_per_copy=128,
)
tiled_copy_pv = cute.make_tiled_copy_C(smem_copy_atom, pv_mma)
thr_pv = tiled_copy_pv.get_slice(sfw_idx)
tSMEM_PVsP = thr_pv.partition_D(sP_2d)
print(f" partition_D(sP_2d) shape: {cute.shape(tSMEM_PVsP)} rank={len(cute.shape(tSMEM_PVsP))}")
except Exception as e:
print(f" make_tiled_copy_C(pv_mma) FAILED: {e}")
# Try make_tiled_copy (not _C) with explicit M/N layout
print(f"\n--- make_tiled_copy (manual M/N) ---")
try:
# The P matrix is (128, n_kv) = (M=128, N_P=128)
# Use (128, 128) tile for the copy
MN = cute.make_layout((128, 128))
smem_copy_atom2 = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(),
BFloat16,
num_bits_per_copy=128,
)
# Need to use make_tiled_copy with proper thread arrangement
# This is for the softmax warps (128 threads = 4 warps)
thr_layout = cute.make_layout((4, 32)) # 4 warps × 32 lanes
val_layout = cute.make_layout((1, 1)) # each thread copies 1 element
tiled_manual = cute.make_tiled_copy(smem_copy_atom2, thr_layout, val_layout)
thr_manual = tiled_manual.get_slice(sfw_idx)
tMANUALsP = thr_manual.partition_D(sP_2d)
print(f" partition_D(sP_2d) shape: {cute.shape(tMANUALsP)} rank={len(cute.shape(tMANUALsP))}")
except Exception as e:
print(f" make_tiled_copy (manual) FAILED: {e}")
print(' Compiling diagnostic...', flush=True)
compiled = cute.compile(_diag_kernel, mQ, mK, mV, mC)
compiled(mQ, mK, mV, mC)
torch.cuda.synchronize()
if __name__ == '__main__':
analyze_layouts(64) # TMEM-P (works)
analyze_layouts(256) # SMEM-P (broken)
main()