D1.3: Layout diagnostic - print all QK C-fragment and PV A-operand shapes

This commit is contained in:
2026-05-23 22:14:35 +00:00
parent b59aca4655
commit c3a7c30f20

View File

@@ -0,0 +1,235 @@
"""
D1.3 SMEM-P diagnostic: print ALL relevant shapes and layouts.
Goal: Understand the exact QK C-fragment partition (source) and
PV A-operand SMEM layout (destination) so we can write a proper
register→SMEM copy for P values.
This test ONLY prints shapes. No kernel execution. Pure layout analysis.
"""
import torch, math
import cutlass, cutlass.cute as cute, cutlass.utils as utils
from cutlass.cute.nvgpu import tcgen05
from cutlass.utils import LayoutEnum
import cutlass.torch as ct
def analyze_layouts(hd, s_k=128):
    """Print the layouts involved in handing P from the QK MMA to the PV MMA.

    Builds a QK tiled MMA and a PV tiled MMA for one head dimension, then
    prints the QK C-fragment, the PV C-fragment, the PV A-operand layout and
    fragments, the per-thread TMEM-load partition, and the outcome of several
    ``make_tiled_copy_C`` / ``partition_*`` probes. Nothing is returned; the
    printed text is the whole result.

    Args:
        hd: Head dimension. ``hd > 64`` selects the SMEM source for the PV
            A operand, otherwise TMEM. The PV N tile is ``min(hd, 256)``.
        s_k: Number of KV positions. Used for the dummy K tensor and in the
            printed P/V dimensions only; the QK tile itself is fixed at
            (128, 128) regardless of this value.

    Notes:
        Allocates tensors with ``device='cuda'``, so a CUDA device is needed.
        Only the tiled-copy probes near the end are wrapped in
        ``try/except Exception`` (best-effort: the failure text is printed and
        the analysis continues). Everything before them is unguarded and
        propagates any exception to the caller.
    """
    print(f"\n{'='*60}")
    print(f" HEAD_DIM={hd}, s_k={s_k}")
    print(f"{'='*60}")

    m = 128
    use_smem_p = hd > 64
    pv_n_tile = min(hd, 256)

    # Dummy operands: only used to derive the MMA major modes below.
    q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
    k = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda')

    qk_mma = utils.sm100.make_trivial_tiled_mma(
        cutlass.BFloat16, cutlass.BFloat16,
        LayoutEnum.from_tensor(ct.from_dlpack(q)).mma_major_mode(),
        LayoutEnum.from_tensor(ct.from_dlpack(k)).mma_mn_mode(),
        cutlass.Float32, tcgen05.CtaGroup.ONE, (128, 128),
        tcgen05.OperandSource.SMEM,
    )

    # The PV A operand (P) comes from SMEM when use_smem_p, else from TMEM.
    pv_a_major = (
        LayoutEnum.from_tensor(ct.from_dlpack(q)).mma_major_mode()
        if use_smem_p
        else cute.nvgpu.OperandMajorMode.K
    )
    pv_source = tcgen05.OperandSource.SMEM if use_smem_p else tcgen05.OperandSource.TMEM
    # NOTE(review): the B-operand major mode is taken from a throwaway
    # (m, pv_n_tile, 1) tensor rather than a V-shaped one — confirm this
    # matches what the real kernel passes.
    pv_mma = utils.sm100.make_trivial_tiled_mma(
        cutlass.BFloat16, cutlass.BFloat16,
        pv_a_major,
        LayoutEnum.from_tensor(ct.from_dlpack(torch.randn(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda'))).mma_major_mode(),
        cutlass.Float32, tcgen05.CtaGroup.ONE, (128, pv_n_tile),
        pv_source,
    )

    # MMA tilers: QK covers 4 instruction-K steps; PV covers as many whole
    # instruction-K steps as fit in 128.
    qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
    qk_mma_tiler = (128, 128, qk_ik * 4)
    pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
    pv_mma_tiler = (128, pv_n_tile, pv_ik * (128 // pv_ik))

    print("\n--- MMA info ---")
    print(f"QK MMA shape_mnk: {qk_mma.shape_mnk}")
    print(f"PV MMA shape_mnk: {pv_mma.shape_mnk}")
    print(f"QK tiler: {qk_mma_tiler}")
    print(f"PV tiler: {pv_mma_tiler}")
    print(f"use_smem_p: {use_smem_p}")
    print(f"PV A source: {pv_source}")

    # QK C-fragment (where S and P live in TMEM)
    qk_thr = qk_mma.get_slice(0)
    qk_as = qk_thr.partition_shape_C(qk_mma_tiler[:2])
    tStS = qk_thr.make_fragment_C(qk_as)
    print("\n--- QK C-fragment (tStS) ---")
    print(f" qk_as (partition shape): {qk_as}")
    print(f" tStS shape: {cute.shape(tStS)}")
    print(f" tStS layout: {tStS.layout}")

    # P in TMEM (subset of S columns). Computed twice on purpose: once with
    # literal bit widths, once from the dtype objects, so the two printed
    # values can be compared.
    p_cols_fp32 = pv_mma_tiler[2] * 16 // 32  # BF16 width / FP32 width
    print(f" p_cols_fp32 (PV MMA tiler[2] * 16/32): {p_cols_fp32}")
    p_cols_fp32_v2 = pv_mma_tiler[2] * cutlass.BFloat16.width // cutlass.Float32.width
    print(f" p_cols_fp32 (PV MMA tiler[2] * BF16_width/FP32_width): {p_cols_fp32_v2}")

    # PV C-fragment (where O lives in TMEM)
    pv_thr = pv_mma.get_slice(0)
    pv_as = pv_thr.partition_shape_C(pv_mma_tiler[:2])
    tOtO = pv_thr.make_fragment_C(pv_as)
    print("\n--- PV C-fragment (tOtO) ---")
    print(f" pv_as (partition shape): {pv_as}")
    print(f" tOtO shape: {cute.shape(tOtO)}")
    print(f" tOtO layout: {tOtO.layout}")

    # PV A-operand (where P is read from in PV GEMM)
    #   TMEM-P: tOrP from TMEM
    #   SMEM-P: tCrP from SMEM, tOrP from sP
    # NOTE(review): both layouts below come from the same call with the same
    # arguments, so they are identical — confirm whether the TMEM variant was
    # meant to be built differently.
    p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, cutlass.BFloat16, 1)
    p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, cutlass.BFloat16, 1)
    print("\n--- PV A-operand layouts ---")
    print(f" p_tmem_s outer: {p_tmem_s.outer}")
    # Read the inner layout from the object the label names (was p_smem_s).
    print(f" p_tmem_s inner (swizzle): {p_tmem_s.inner}")

    # sP would be allocated with the PV A-operand SMEM layout.
    print("\n--- PV A-operand SMEM layout (sP) ---")
    print(f" p_smem_s outer: {p_smem_s.outer}")
    print(f" p_smem_s outer shape: {cute.shape(cute.make_layout(p_smem_s.outer))}")
    print(f" p_smem_s inner: {p_smem_s.inner}")

    # Build dummy P tensors so the PV MMA can produce A-operand fragments.
    # NOTE(review): cute.find(...) is used here as the tensor iterator —
    # confirm it yields a valid iterator for make_tensor.
    tP_tmem = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
    tP_smem = cute.make_tensor(cute.find(cute.make_layout(p_smem_s.outer), 0), p_smem_s.outer)
    tOrP_tmem = pv_thr.make_fragment_A(tP_tmem)
    tCrP_smem = pv_mma.make_fragment_A(tP_smem)
    tOrP_smem = pv_thr.make_fragment_A(tP_smem)
    print("\n--- PV A-operand fragments ---")
    print(f" tOrP_tmem (TMEM-P read): shape={cute.shape(tOrP_tmem)} layout={tOrP_tmem.layout}")
    print(f" tCrP_smem (SMEM-P load): shape={cute.shape(tCrP_smem)} layout={tCrP_smem.layout}")
    print(f" tOrP_smem (SMEM-P read): shape={cute.shape(tOrP_smem)} layout={tOrP_smem.layout}")

    # The softmax warps hold P in the QK C-fragment layout (tStS) and must
    # write it to sP, so inspect the per-thread partition of that fragment.
    sfw_idx = 0  # thread 0 in softmax warps

    # TMEM load atom for S/P and the resulting per-thread partitions.
    tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), cutlass.Float32)
    tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS)
    thr_load = tiled_tmem_load.get_slice(sfw_idx)
    tTMEM_LOADtS = thr_load.partition_S(tStS)
    cS = cute.make_identity_tensor((128, 128))
    tScS = qk_thr.partition_C(cS)
    tTMEM_LOADcS = thr_load.partition_D(tScS)
    print("\n--- Softmax thread 0 TMEM load partition ---")
    print(f" tTMEM_LOADtS shape: {cute.shape(tTMEM_LOADtS)}")
    print(f" tTMEM_LOADcS shape: {cute.shape(tTMEM_LOADcS)}")
    print(f" tTMEM_LOADcS layout: {tTMEM_LOADcS.layout}")
    print(f" => Each softmax thread has {cute.size(tTMEM_LOADcS)} S coordinates")

    # Probe 1: tiled copy derived from the QK MMA's C partitioning.
    # Best-effort by design: any failure is printed and the analysis goes on.
    print("\n--- make_tiled_copy_C(qk_mma) attempt ---")
    try:
        smem_copy_atom = cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(),
            cutlass.BFloat16,
            num_bits_per_copy=128,
        )
        tiled_smem_copy = cute.make_tiled_copy_C(smem_copy_atom, qk_mma)
        thr_smem_copy = tiled_smem_copy.get_slice(sfw_idx)
        sP_2d = cute.group_modes(cute.make_tensor(cute.find(cute.make_layout(p_smem_s.outer), 0), p_smem_s.outer), 0, 3)
        tSMEM_CPYsP = thr_smem_copy.partition_D(sP_2d)
        print(" tiled_smem_copy: created OK")
        print(f" sP_2d shape: {cute.shape(sP_2d)} rank: {len(cute.shape(sP_2d))}")
        print(f" tSMEM_CPYsP shape: {cute.shape(tSMEM_CPYsP)} rank: {len(cute.shape(tSMEM_CPYsP))}")

        # Source side, attempt A: a tensor shaped like the TMEM-load
        # coordinate partition (the layout rS/rP registers would share).
        rP_layout = tTMEM_LOADcS.layout  # coordinate layout
        rP_bf16 = cute.make_tensor(cute.find(rP_layout, 0), rP_layout)
        try:
            tSMEM_CPYrP = thr_smem_copy.partition_S(rP_bf16)
            print(f" tSMEM_CPYrP (from coord layout) shape: {cute.shape(tSMEM_CPYrP)} rank: {len(cute.shape(tSMEM_CPYrP))}")
        except Exception as e:
            print(f" partition_S with coord layout FAILED: {e}")

        # Source side, attempt B: a tensor shaped like the QK C-fragment.
        try:
            rP_qk = cute.make_tensor(cute.find(tStS.layout, 0), tStS.layout)
            tSMEM_CPYrP = thr_smem_copy.partition_S(rP_qk)
            print(f" tSMEM_CPYrP (from QK C-frag layout) shape: {cute.shape(tSMEM_CPYrP)} rank: {len(cute.shape(tSMEM_CPYrP))}")
        except Exception as e:
            print(f" partition_S with QK C-frag layout FAILED: {e}")
    except Exception as e:
        # Also reached if get_slice/partition_D above fail, not only the
        # make_tiled_copy_C call itself.
        print(f" make_tiled_copy_C FAILED: {e}")

    # Probe 2: tiled copy derived from the PV MMA. Despite the printed
    # section header, this also goes through make_tiled_copy_C — it takes its
    # thread partitioning from pv_mma's C layout and is then asked to
    # partition the PV A-operand SMEM tensor as destination.
    print("\n--- make_tiled_copy with PV A-operand SMEM layout ---")
    try:
        store_atom = cute.make_copy_atom(
            cute.nvgpu.CopyUniversalOp(),
            cutlass.BFloat16,
            num_bits_per_copy=128,
        )
        tiled_copy_pv = cute.make_tiled_copy_C(store_atom, pv_mma)
        print(" make_tiled_copy_C(pv_mma) created OK")
        thr_pv = tiled_copy_pv.get_slice(sfw_idx)
        sP_2d = cute.group_modes(cute.make_tensor(cute.find(cute.make_layout(p_smem_s.outer), 0), p_smem_s.outer), 0, 3)
        tSMEM_PVsP = thr_pv.partition_D(sP_2d)
        print(f" tSMEM_PVsP shape: {cute.shape(tSMEM_PVsP)} rank: {len(cute.shape(tSMEM_PVsP))}")
    except Exception as e:
        print(f" make_tiled_copy_C(pv_mma) FAILED: {e}")

    # Full layout details for manual analysis. The hasattr guards keep this
    # section from raising if the outer layout object lacks these attributes.
    print("\n--- Detailed layout info ---")
    print(f" p_smem_s.outer type: {type(p_smem_s.outer)}")
    print(f" p_smem_s.outer: {p_smem_s.outer}")
    if hasattr(p_smem_s.outer, 'shape'):
        print(f" p_smem_s.outer.shape: {p_smem_s.outer.shape}")
    if hasattr(p_smem_s.outer, 'stride'):
        print(f" p_smem_s.outer.stride: {p_smem_s.outer.stride}")

    # P matrix dimensions.
    # The PV GEMM is P[M, n_kv] @ V[n_kv, hd], so P is (128, n_kv): its second
    # axis is the KV-position axis (n_kv = s_k), not the head dimension. The
    # PV MMA consumes it in sub-tiles of pv_mma_tiler[2] along that axis.
    print("\n--- P matrix dimensions ---")
    print(f" P is (128, {s_k}) attention weights matrix")
    print(f" PV GEMM: P[128, {s_k}] @ V[{s_k}, {pv_n_tile}] = O[128, {pv_n_tile}]")
    print(f" PV A-operand shape: (128, {s_k}) = QK MMA's K dimension")
    print(f" But PV MMA tiler K dimension: {pv_mma_tiler[2]}")
    print(f" So PV A-operand is (128, {pv_mma_tiler[2]}) per PV sub-tile")

    # Key insight: the QK C-fragment (S/P as produced) and the PV A-operand
    # SMEM tensor (P as consumed) describe the SAME 128 x n_kv values, but the
    # two MMAs partition them differently — which is why a dedicated
    # register->SMEM copy is needed. Compare the shapes printed above.
if __name__ == '__main__':
    # Run the analysis once per PV A-operand path, in this order:
    #   hd=64  -> hd > 64 is False -> TMEM-P (author's note: works)
    #   hd=256 -> hd > 64 is True  -> SMEM-P (author's note: broken)
    for head_dim in (64, 256):
        analyze_layouts(head_dim)