D1.3: Fix layout diagnostic - remove JIT-dependent code

This commit is contained in:
2026-05-23 22:15:47 +00:00
parent c3a7c30f20
commit ec8fd1474c

View File

@@ -6,11 +6,14 @@ PV A-operand SMEM layout (destination) so we can write a proper
register→SMEM copy for P values.
This test ONLY prints shapes. No kernel execution. Pure layout analysis.
Uses FmhaKernel._setup to compute layouts the same way the kernel does.
"""
import torch, math
import cutlass, cutlass.cute as cute, cutlass.utils as utils
from cutlass.cute.nvgpu import tcgen05
from cutlass.cute.nvgpu import tcgen05, cpasync
from cutlass import Float32, BFloat16, Int32
from cutlass.utils import LayoutEnum
from cutlass.utils.tmem_allocator import find_tmem_tensor_col_offset
import cutlass.torch as ct
@@ -22,94 +25,79 @@ def analyze_layouts(hd, s_k=128):
m = 128
use_smem_p = hd > 64
pv_n_tile = min(hd, 256)
n_kv_tiles = s_k // 128
# Create dummy tensors just for layout computation
q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
k = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda')
v = torch.randn(s_k, hd, dtype=torch.bfloat16, device='cuda')
c = torch.randn(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda')
# Reconstruct V with FMHA layout
v_fmha = cute.make_tensor(
v.data_ptr(),
cute.make_layout((pv_n_tile, s_k, 1), stride=(1, pv_n_tile, pv_n_tile * s_k)),
)
a_major = LayoutEnum.from_tensor(ct.from_dlpack(q)).mma_major_mode()
b_major = LayoutEnum.from_tensor(ct.from_dlpack(k)).mma_major_mode()
c_major = LayoutEnum.from_tensor(ct.from_dlpack(c)).mma_major_mode()
# Build QK MMA
qk_mma = utils.sm100.make_trivial_tiled_mma(
cutlass.BFloat16, cutlass.BFloat16,
LayoutEnum.from_tensor(ct.from_dlpack(q)).mma_major_mode(),
LayoutEnum.from_tensor(ct.from_dlpack(k)).mma_mn_mode(),
cutlass.Float32, tcgen05.CtaGroup.ONE, (128, 128),
tcgen05.OperandSource.SMEM,
BFloat16, BFloat16, a_major, b_major, Float32,
tcgen05.CtaGroup.ONE, (128, 128), tcgen05.OperandSource.SMEM,
)
pv_a_major = LayoutEnum.from_tensor(ct.from_dlpack(q)).mma_major_mode() if use_smem_p else cute.nvgpu.OperandMajorMode.K
# Build PV MMA (SMEM-P for hd>64, TMEM-P for hd<=64)
pv_a_major = a_major if use_smem_p else cute.nvgpu.OperandMajorMode.K
pv_source = tcgen05.OperandSource.SMEM if use_smem_p else tcgen05.OperandSource.TMEM
pv_mma = utils.sm100.make_trivial_tiled_mma(
cutlass.BFloat16, cutlass.BFloat16,
pv_a_major,
LayoutEnum.from_tensor(ct.from_dlpack(torch.randn(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda'))).mma_major_mode(),
cutlass.Float32, tcgen05.CtaGroup.ONE, (128, pv_n_tile),
pv_source,
BFloat16, BFloat16, pv_a_major, c_major, Float32,
tcgen05.CtaGroup.ONE, (128, pv_n_tile), pv_source,
)
# MMA tiler
# Tiler computation (same as FmhaKernel._setup)
qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
qk_mma_tiler = (128, 128, qk_ik * 4)
pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
pv_mma_tiler = (128, pv_n_tile, pv_ik * (128 // pv_ik))
print(f"\n--- MMA info ---")
print(f"QK MMA shape_mnk: {qk_mma.shape_mnk}")
print(f"PV MMA shape_mnk: {pv_mma.shape_mnk}")
print(f"QK tiler: {qk_mma_tiler}")
print(f"PV tiler: {pv_mma_tiler}")
print(f"use_smem_p: {use_smem_p}")
print(f"PV A source: {pv_source}")
print(f" QK MMA shape_mnk: {qk_mma.shape_mnk}")
print(f" PV MMA shape_mnk: {pv_mma.shape_mnk}")
print(f" QK tiler: {qk_mma_tiler}")
print(f" PV tiler: {pv_mma_tiler}")
print(f" use_smem_p: {use_smem_p}")
print(f" PV A source: {pv_source}")
print(f" PV A major: {pv_a_major}")
# QK C-fragment (where S and P live in TMEM)
# QK C-fragment (where S lives in TMEM, where softmax reads P from)
qk_thr = qk_mma.get_slice(0)
qk_as = qk_thr.partition_shape_C(qk_mma_tiler[:2])
tStS = qk_thr.make_fragment_C(qk_as)
print(f"\n--- QK C-fragment (tStS) ---")
print(f"\n--- QK C-fragment (tStS = S and P in TMEM) ---")
print(f" qk_as (partition shape): {qk_as}")
print(f" tStS shape: {cute.shape(tStS)}")
print(f" tStS size: {cute.size(tStS)}")
print(f" tStS layout: {tStS.layout}")
# P in TMEM (subset of S columns)
p_cols_fp32 = pv_mma_tiler[2] * 16 // 32 # BF16 width / FP32 width
print(f" p_cols_fp32 (PV MMA tiler[2] * 16/32): {p_cols_fp32}")
p_cols_fp32_v2 = pv_mma_tiler[2] * cutlass.BFloat16.width // cutlass.Float32.width
print(f" p_cols_fp32 (PV MMA tiler[2] * BF16_width/FP32_width): {p_cols_fp32_v2}")
# PV C-fragment (where O lives in TMEM)
pv_thr = pv_mma.get_slice(0)
pv_as = pv_thr.partition_shape_C(pv_mma_tiler[:2])
tOtO = pv_thr.make_fragment_C(pv_as)
print(f"\n--- PV C-fragment (tOtO) ---")
print(f"\n--- PV C-fragment (tOtO = O in TMEM) ---")
print(f" pv_as (partition shape): {pv_as}")
print(f" tOtO shape: {cute.shape(tOtO)}")
print(f" tOtO size: {cute.size(tOtO)}")
print(f" tOtO layout: {tOtO.layout}")
# PV A-operand (where P is read from in PV GEMM)
# TMEM-P: tOrP from TMEM
# SMEM-P: tCrP from SMEM, tOrP from sP
p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, cutlass.BFloat16, 1)
p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, cutlass.BFloat16, 1)
print(f"\n--- PV A-operand layouts ---")
print(f" p_tmem_s outer: {p_tmem_s.outer}")
print(f" p_tmem_s inner (swizzle): {p_smem_s.inner}")
# Create sP tensor (SMEM allocation)
# We need to allocate sP with the PV A-operand SMEM layout
# PV A-operand SMEM layout (where P must be written for SMEM-P)
p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1)
print(f"\n--- PV A-operand SMEM layout (sP) ---")
print(f" p_smem_s outer: {p_smem_s.outer}")
print(f" p_smem_s outer shape: {cute.shape(cute.make_layout(p_smem_s.outer))}")
print(f" p_smem_s inner: {p_smem_s.inner}")
print(f" p_smem_s inner (swizzle): {p_smem_s.inner}")
outer_layout = cute.make_layout(p_smem_s.outer)
print(f" p_smem_s outer shape: {cute.shape(outer_layout)}")
print(f" p_smem_s outer size: {cute.size(outer_layout)}")
# What does the MMA A-operand fragment look like for SMEM-P?
# This is what the MMA warp reads from sP
# We need a dummy sP tensor to create the fragment
# Use composition to get the A-operand partition
# PV A-operand fragments (what MMA reads from TMEM or SMEM)
# For TMEM-P: tOrP reads from TMEM
# For SMEM-P: tCrP reads from SMEM, tOrP also reads from SMEM
p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1)
tP_tmem = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
tP_smem = cute.make_tensor(cute.find(cute.make_layout(p_smem_s.outer), 0), p_smem_s.outer)
@@ -118,17 +106,24 @@ def analyze_layouts(hd, s_k=128):
tOrP_smem = pv_thr.make_fragment_A(tP_smem)
print(f"\n--- PV A-operand fragments ---")
print(f" tOrP_tmem (TMEM-P read): shape={cute.shape(tOrP_tmem)} layout={tOrP_tmem.layout}")
print(f" tCrP_smem (SMEM-P load): shape={cute.shape(tCrP_smem)} layout={tCrP_smem.layout}")
print(f" tOrP_smem (SMEM-P read): shape={cute.shape(tOrP_smem)} layout={tOrP_smem.layout}")
print(f" tOrP_tmem (TMEM-P A-read): shape={cute.shape(tOrP_tmem)}")
print(f" tCrP_smem (SMEM-P A-load): shape={cute.shape(tCrP_smem)} layout={tCrP_smem.layout}")
print(f" tOrP_smem (SMEM-P A-read): shape={cute.shape(tOrP_smem)} layout={tOrP_smem.layout}")
print(f" tOrP_tmem size: {cute.size(tOrP_tmem)}")
print(f" tCrP_smem size: {cute.size(tCrP_smem)}")
# The softmax warps' view of P
# They have P in QK C-fragment layout (tStS), and they need to write to sP
# Let's understand the QK C-fragment partition per-thread
sfw_idx = 0 # thread 0 in softmax warps
# Slice tOrP the way the kernel does: [(None,None,None,0)]
tOrP_base_tmem = pv_thr.make_fragment_A(tP_tmem)
tOrP_tmem_sliced = tOrP_base_tmem[(None,None,None,0)]
tOrP_base_smem = pv_thr.make_fragment_A(tP_smem)
tOrP_smem_sliced = tOrP_base_smem[(None,None,None,0)]
print(f"\n--- PV A-operand after slice [(None,None,None,0)] ---")
print(f" tOrP_tmem_sliced: shape={cute.shape(tOrP_tmem_sliced)}")
print(f" tOrP_smem_sliced: shape={cute.shape(tOrP_smem_sliced)}")
# TMEM load/store atoms for S/P
tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), cutlass.Float32)
# Softmax thread partition of S (via TMEM load atoms)
sfw_idx = 0
tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32)
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS)
thr_load = tiled_tmem_load.get_slice(sfw_idx)
tTMEM_LOADtS = thr_load.partition_S(tStS)
@@ -136,98 +131,109 @@ def analyze_layouts(hd, s_k=128):
tScS = qk_thr.partition_C(cS)
tTMEM_LOADcS = thr_load.partition_D(tScS)
print(f"\n--- Softmax thread 0 TMEM load partition ---")
print(f"\n--- Softmax thread 0: TMEM load partition ---")
print(f" tTMEM_LOADtS shape: {cute.shape(tTMEM_LOADtS)}")
print(f" tTMEM_LOADcS shape: {cute.shape(tTMEM_LOADcS)}")
print(f" tTMEM_LOADcS layout: {tTMEM_LOADcS.layout}")
print(f" => Each softmax thread has {cute.size(tTMEM_LOADcS)} S coordinates")
print(f" => Each softmax thread has {cute.size(tTMEM_LOADcS)} S/P coordinates")
# Now try make_tiled_copy_C with qk_mma
print(f"\n--- make_tiled_copy_C(qk_mma) attempt ---")
# P register bridge: what P values does each softmax thread have?
# After softmax, P is in rP_bf16 which shares layout with rS
# rP_bf16 = make_tensor(recast_ptr(rP_words.iterator, dtype=BF16), tTMEM_LOADrS.layout)
# where tTMEM_LOADrS = make_rmem_tensor(tTMEM_LOADcS.shape, FP32)
p_cols_fp32 = pv_mma_tiler[2] * BFloat16.width // Float32.width
tStP_layout = cute.composition(tStS.layout, cute.make_layout((pv_mma_tiler[0], p_cols_fp32)))
print(f"\n--- P columns in TMEM ---")
print(f" p_cols_fp32 (PV tiler[2] * BF16/FP32): {p_cols_fp32}")
print(f" P TMEM sub-layout (composition of tStS): {tStP_layout}")
print(f" P TMEM sub-layout shape: {cute.shape(cute.make_layout(tStP_layout))}")
# Now the critical question: what's the sP shape that the softmax needs to write to?
sP_alloc = cute.make_tensor(cute.find(cute.make_layout(p_smem_s.outer), 0), p_smem_s.outer)
sP_2d = cute.group_modes(sP_alloc, 0, 3)
print(f"\n--- sP for SMEM-P write target ---")
print(f" sP shape (3D): {cute.shape(sP_alloc)}")
print(f" sP_2d shape (grouped): {cute.shape(sP_2d)}")
print(f" sP_2d rank: {len(cute.shape(sP_2d))}")
# Try make_tiled_copy_C with QK MMA (the approach that gave rank mismatch)
print(f"\n--- make_tiled_copy_C(smem_copy_atom, qk_mma) ---")
try:
smem_copy_atom = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(),
cutlass.BFloat16,
BFloat16,
num_bits_per_copy=128,
)
tiled_smem_copy = cute.make_tiled_copy_C(smem_copy_atom, qk_mma)
thr_smem_copy = tiled_smem_copy.get_slice(sfw_idx)
sP_2d = cute.group_modes(cute.make_tensor(cute.find(cute.make_layout(p_smem_s.outer), 0), p_smem_s.outer), 0, 3)
tSMEM_CPYsP = thr_smem_copy.partition_D(sP_2d)
print(f" tiled_smem_copy: created OK")
print(f" sP_2d shape: {cute.shape(sP_2d)} rank: {len(cute.shape(sP_2d))}")
print(f" tSMEM_CPYsP shape: {cute.shape(tSMEM_CPYsP)} rank: {len(cute.shape(tSMEM_CPYsP))}")
tiled_copy_qk = cute.make_tiled_copy_C(smem_copy_atom, qk_mma)
thr_qk = tiled_copy_qk.get_slice(sfw_idx)
tSMEM_QKsP = thr_qk.partition_D(sP_2d)
print(f" partition_D(sP_2d) shape: {cute.shape(tSMEM_QKsP)} rank={len(cute.shape(tSMEM_QKsP))}")
# Try partition_S with the rP register layout
# rP is in the same layout as rS (TMEM load partition)
rP_layout = tTMEM_LOADcS.layout # coordinate layout
rP_bf16 = cute.make_tensor(cute.find(rP_layout, 0), rP_layout)
# Try partition_S with various source layouts
# 1. Identity tensor (coordinate) layout
try:
tSMEM_CPYrP = thr_smem_copy.partition_S(rP_bf16)
print(f" tSMEM_CPYrP (from coord layout) shape: {cute.shape(tSMEM_CPYrP)} rank: {len(cute.shape(tSMEM_CPYrP))}")
rP_coord = cute.make_tensor(cute.find(tTMEM_LOADcS.layout, 0), tTMEM_LOADcS.layout)
tSMEM_QKrP = thr_qk.partition_S(rP_coord)
print(f" partition_S(coord layout) shape: {cute.shape(tSMEM_QKrP)} rank={len(cute.shape(tSMEM_QKrP))}")
except Exception as e:
print(f" partition_S with coord layout FAILED: {e}")
print(f" partition_S(coord layout) FAILED: {e}")
# Try with tStS layout (QK C-fragment)
# 2. QK C-fragment layout (tStS)
try:
rP_qk = cute.make_tensor(cute.find(tStS.layout, 0), tStS.layout)
tSMEM_CPYrP = thr_smem_copy.partition_S(rP_qk)
print(f" tSMEM_CPYrP (from QK C-frag layout) shape: {cute.shape(tSMEM_CPYrP)} rank: {len(cute.shape(tSMEM_CPYrP))}")
tSMEM_QKrP = thr_qk.partition_S(rP_qk)
print(f" partition_S(QK C-frag layout) shape: {cute.shape(tSMEM_QKrP)} rank={len(cute.shape(tSMEM_QKrP))}")
except Exception as e:
print(f" partition_S with QK C-frag layout FAILED: {e}")
print(f" partition_S(QK C-frag layout) FAILED: {e}")
# 3. The rmem layout that softmax actually has P in
rP_rmem_layout = cute.make_layout(tTMEM_LOADcS.shape)
try:
rP_rmem = cute.make_tensor(cute.find(rP_rmem_layout, 0), rP_rmem_layout)
tSMEM_QKrP = thr_qk.partition_S(rP_rmem)
print(f" partition_S(rmem layout) shape: {cute.shape(tSMEM_QKrP)} rank={len(cute.shape(tSMEM_QKrP))}")
except Exception as e:
print(f" partition_S(rmem layout) FAILED: {e}")
except Exception as e:
print(f" make_tiled_copy_C FAILED: {e}")
print(f" make_tiled_copy_C(qk_mma) FAILED: {e}")
# Now try make_tiled_copy (NOT make_tiled_copy_C) with pv_mma
print(f"\n--- make_tiled_copy with PV A-operand SMEM layout ---")
# Try make_tiled_copy_C with PV MMA (different approach)
print(f"\n--- make_tiled_copy_C(smem_copy_atom, pv_mma) ---")
try:
# Use the PV MMA's A-operand thread mapping
# make_tiled_copy uses an M/N layout to determine thread partitioning
store_atom = cute.make_copy_atom(
smem_copy_atom = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(),
cutlass.BFloat16,
BFloat16,
num_bits_per_copy=128,
)
# Create a tiled copy that uses the QK MMA's C-fragment layout for thread partitioning
# but writes to the PV A-operand SMEM layout
tiled_copy_pv = cute.make_tiled_copy_C(store_atom, pv_mma)
print(f" make_tiled_copy_C(pv_mma) created OK")
tiled_copy_pv = cute.make_tiled_copy_C(smem_copy_atom, pv_mma)
thr_pv = tiled_copy_pv.get_slice(sfw_idx)
sP_2d = cute.group_modes(cute.make_tensor(cute.find(cute.make_layout(p_smem_s.outer), 0), p_smem_s.outer), 0, 3)
tSMEM_PVsP = thr_pv.partition_D(sP_2d)
print(f" tSMEM_PVsP shape: {cute.shape(tSMEM_PVsP)} rank: {len(cute.shape(tSMEM_PVsP))}")
print(f" partition_D(sP_2d) shape: {cute.shape(tSMEM_PVsP)} rank={len(cute.shape(tSMEM_PVsP))}")
except Exception as e:
print(f" make_tiled_copy_C(pv_mma) FAILED: {e}")
# Print the full layout details for manual analysis
print(f"\n--- Detailed layout info ---")
print(f" p_smem_s.outer type: {type(p_smem_s.outer)}")
print(f" p_smem_s.outer: {p_smem_s.outer}")
if hasattr(p_smem_s.outer, 'shape'):
print(f" p_smem_s.outer.shape: {p_smem_s.outer.shape}")
if hasattr(p_smem_s.outer, 'stride'):
print(f" p_smem_s.outer.stride: {p_smem_s.outer.stride}")
# Understand the P matrix dimensions
# P is (128, n_kv) in the attention computation
# But PV MMA sees P as (128, pv_n_tile) A-operand
# For hd=64: P is (128, 64) with n_kv=128 KV positions
# Actually P is (M, K_PV) where K_PV = n_kv for each head
# Wait - PV is P[M, n_kv] @ V[n_kv, hd]
# So P's second dimension is n_kv (= 128 for s_k=128)
print(f"\n--- P matrix dimensions ---")
print(f" P is (128, {s_k}) attention weights matrix")
print(f" PV GEMM: P[128, {s_k}] @ V[{s_k}, {pv_n_tile}] = O[128, {pv_n_tile}]")
print(f" PV A-operand shape: (128, {s_k}) = QK MMA's K dimension")
print(f" But PV MMA tiler K dimension: {pv_mma_tiler[2]}")
print(f" So PV A-operand is (128, {pv_mma_tiler[2]}) per PV sub-tile")
# Key insight: P is (128, n_kv), where n_kv is the QK MMA's N dimension and the PV MMA's K dimension
# The QK C-fragment has shape ((128, 128), 1, 1, 1) for S
# The PV A-operand SMEM has shape that depends on PV MMA tiling
# These represent the SAME data (128 rows, 128 KV positions)
# But they're partitioned differently by the two MMAs
# Try make_tiled_copy (not _C) with explicit M/N layout
print(f"\n--- make_tiled_copy (manual M/N) ---")
try:
# The P matrix is (128, n_kv) = (M=128, N_P=128)
# Use (128, 128) tile for the copy
MN = cute.make_layout((128, 128))
smem_copy_atom2 = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(),
BFloat16,
num_bits_per_copy=128,
)
# Need to use make_tiled_copy with proper thread arrangement
# This is for the softmax warps (128 threads = 4 warps)
thr_layout = cute.make_layout((4, 32)) # 4 warps × 32 lanes
val_layout = cute.make_layout((1, 1)) # each thread copies 1 element
tiled_manual = cute.make_tiled_copy(smem_copy_atom2, thr_layout, val_layout)
thr_manual = tiled_manual.get_slice(sfw_idx)
tMANUALsP = thr_manual.partition_D(sP_2d)
print(f" partition_D(sP_2d) shape: {cute.shape(tMANUALsP)} rank={len(cute.shape(tMANUALsP))}")
except Exception as e:
print(f" make_tiled_copy (manual) FAILED: {e}")
if __name__ == '__main__':