D1.3: Layout diagnostic - print all QK C-fragment and PV A-operand shapes
This commit is contained in:
235
tests/unit/test_d1_3_layout_diag.py
Normal file
235
tests/unit/test_d1_3_layout_diag.py
Normal file
@@ -0,0 +1,235 @@
|
||||
"""
|
||||
D1.3 SMEM-P diagnostic: print ALL relevant shapes and layouts.
|
||||
|
||||
Goal: Understand the exact QK C-fragment partition (source) and
|
||||
PV A-operand SMEM layout (destination) so we can write a proper
|
||||
register→SMEM copy for P values.
|
||||
|
||||
This test ONLY prints shapes. No kernel execution. Pure layout analysis.
|
||||
"""
|
||||
import torch, math
|
||||
import cutlass, cutlass.cute as cute, cutlass.utils as utils
|
||||
from cutlass.cute.nvgpu import tcgen05
|
||||
from cutlass.utils import LayoutEnum
|
||||
import cutlass.torch as ct
|
||||
|
||||
|
||||
def analyze_layouts(hd, s_k=128):
|
||||
print(f"\n{'='*60}")
|
||||
print(f" HEAD_DIM={hd}, s_k={s_k}")
|
||||
print(f"{'='*60}")
|
||||
|
||||
m = 128
|
||||
use_smem_p = hd > 64
|
||||
pv_n_tile = min(hd, 256)
|
||||
|
||||
q = torch.randn(m, hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
k = torch.randn(s_k, hd, 1, dtype=torch.bfloat16, device='cuda')
|
||||
v = torch.randn(s_k, hd, dtype=torch.bfloat16, device='cuda')
|
||||
|
||||
# Reconstruct V with FMHA layout
|
||||
v_fmha = cute.make_tensor(
|
||||
v.data_ptr(),
|
||||
cute.make_layout((pv_n_tile, s_k, 1), stride=(1, pv_n_tile, pv_n_tile * s_k)),
|
||||
)
|
||||
|
||||
qk_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
cutlass.BFloat16, cutlass.BFloat16,
|
||||
LayoutEnum.from_tensor(ct.from_dlpack(q)).mma_major_mode(),
|
||||
LayoutEnum.from_tensor(ct.from_dlpack(k)).mma_mn_mode(),
|
||||
cutlass.Float32, tcgen05.CtaGroup.ONE, (128, 128),
|
||||
tcgen05.OperandSource.SMEM,
|
||||
)
|
||||
|
||||
pv_a_major = LayoutEnum.from_tensor(ct.from_dlpack(q)).mma_major_mode() if use_smem_p else cute.nvgpu.OperandMajorMode.K
|
||||
pv_source = tcgen05.OperandSource.SMEM if use_smem_p else tcgen05.OperandSource.TMEM
|
||||
pv_mma = utils.sm100.make_trivial_tiled_mma(
|
||||
cutlass.BFloat16, cutlass.BFloat16,
|
||||
pv_a_major,
|
||||
LayoutEnum.from_tensor(ct.from_dlpack(torch.randn(m, pv_n_tile, 1, dtype=torch.bfloat16, device='cuda'))).mma_major_mode(),
|
||||
cutlass.Float32, tcgen05.CtaGroup.ONE, (128, pv_n_tile),
|
||||
pv_source,
|
||||
)
|
||||
|
||||
# MMA tiler
|
||||
qk_ik = cute.size(qk_mma.shape_mnk, mode=[2])
|
||||
qk_mma_tiler = (128, 128, qk_ik * 4)
|
||||
pv_ik = cute.size(pv_mma.shape_mnk, mode=[2])
|
||||
pv_mma_tiler = (128, pv_n_tile, pv_ik * (128 // pv_ik))
|
||||
|
||||
print(f"\n--- MMA info ---")
|
||||
print(f"QK MMA shape_mnk: {qk_mma.shape_mnk}")
|
||||
print(f"PV MMA shape_mnk: {pv_mma.shape_mnk}")
|
||||
print(f"QK tiler: {qk_mma_tiler}")
|
||||
print(f"PV tiler: {pv_mma_tiler}")
|
||||
print(f"use_smem_p: {use_smem_p}")
|
||||
print(f"PV A source: {pv_source}")
|
||||
|
||||
# QK C-fragment (where S and P live in TMEM)
|
||||
qk_thr = qk_mma.get_slice(0)
|
||||
qk_as = qk_thr.partition_shape_C(qk_mma_tiler[:2])
|
||||
tStS = qk_thr.make_fragment_C(qk_as)
|
||||
print(f"\n--- QK C-fragment (tStS) ---")
|
||||
print(f" qk_as (partition shape): {qk_as}")
|
||||
print(f" tStS shape: {cute.shape(tStS)}")
|
||||
print(f" tStS layout: {tStS.layout}")
|
||||
|
||||
# P in TMEM (subset of S columns)
|
||||
p_cols_fp32 = pv_mma_tiler[2] * 16 // 32 # BF16 width / FP32 width
|
||||
print(f" p_cols_fp32 (PV MMA tiler[2] * 16/32): {p_cols_fp32}")
|
||||
p_cols_fp32_v2 = pv_mma_tiler[2] * cutlass.BFloat16.width // cutlass.Float32.width
|
||||
print(f" p_cols_fp32 (PV MMA tiler[2] * BF16_width/FP32_width): {p_cols_fp32_v2}")
|
||||
|
||||
# PV C-fragment (where O lives in TMEM)
|
||||
pv_thr = pv_mma.get_slice(0)
|
||||
pv_as = pv_thr.partition_shape_C(pv_mma_tiler[:2])
|
||||
tOtO = pv_thr.make_fragment_C(pv_as)
|
||||
print(f"\n--- PV C-fragment (tOtO) ---")
|
||||
print(f" pv_as (partition shape): {pv_as}")
|
||||
print(f" tOtO shape: {cute.shape(tOtO)}")
|
||||
print(f" tOtO layout: {tOtO.layout}")
|
||||
|
||||
# PV A-operand (where P is read from in PV GEMM)
|
||||
# TMEM-P: tOrP from TMEM
|
||||
# SMEM-P: tCrP from SMEM, tOrP from sP
|
||||
p_tmem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, cutlass.BFloat16, 1)
|
||||
p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, cutlass.BFloat16, 1)
|
||||
|
||||
print(f"\n--- PV A-operand layouts ---")
|
||||
print(f" p_tmem_s outer: {p_tmem_s.outer}")
|
||||
print(f" p_tmem_s inner (swizzle): {p_smem_s.inner}")
|
||||
|
||||
# Create sP tensor (SMEM allocation)
|
||||
# We need to allocate sP with the PV A-operand SMEM layout
|
||||
print(f"\n--- PV A-operand SMEM layout (sP) ---")
|
||||
print(f" p_smem_s outer: {p_smem_s.outer}")
|
||||
print(f" p_smem_s outer shape: {cute.shape(cute.make_layout(p_smem_s.outer))}")
|
||||
print(f" p_smem_s inner: {p_smem_s.inner}")
|
||||
|
||||
# What does the MMA A-operand fragment look like for SMEM-P?
|
||||
# This is what the MMA warp reads from sP
|
||||
# We need a dummy sP tensor to create the fragment
|
||||
# Use composition to get the A-operand partition
|
||||
tP_tmem = cute.make_tensor(tStS.iterator, p_tmem_s.outer)
|
||||
tP_smem = cute.make_tensor(cute.find(cute.make_layout(p_smem_s.outer), 0), p_smem_s.outer)
|
||||
|
||||
tOrP_tmem = pv_thr.make_fragment_A(tP_tmem)
|
||||
tCrP_smem = pv_mma.make_fragment_A(tP_smem)
|
||||
tOrP_smem = pv_thr.make_fragment_A(tP_smem)
|
||||
|
||||
print(f"\n--- PV A-operand fragments ---")
|
||||
print(f" tOrP_tmem (TMEM-P read): shape={cute.shape(tOrP_tmem)} layout={tOrP_tmem.layout}")
|
||||
print(f" tCrP_smem (SMEM-P load): shape={cute.shape(tCrP_smem)} layout={tCrP_smem.layout}")
|
||||
print(f" tOrP_smem (SMEM-P read): shape={cute.shape(tOrP_smem)} layout={tOrP_smem.layout}")
|
||||
|
||||
# The softmax warps' view of P
|
||||
# They have P in QK C-fragment layout (tStS), and they need to write to sP
|
||||
# Let's understand the QK C-fragment partition per-thread
|
||||
sfw_idx = 0 # thread 0 in softmax warps
|
||||
|
||||
# TMEM load/store atoms for S/P
|
||||
tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), cutlass.Float32)
|
||||
tiled_tmem_load = tcgen05.make_tmem_copy(tmem_load_atom, tStS)
|
||||
thr_load = tiled_tmem_load.get_slice(sfw_idx)
|
||||
tTMEM_LOADtS = thr_load.partition_S(tStS)
|
||||
cS = cute.make_identity_tensor((128, 128))
|
||||
tScS = qk_thr.partition_C(cS)
|
||||
tTMEM_LOADcS = thr_load.partition_D(tScS)
|
||||
|
||||
print(f"\n--- Softmax thread 0 TMEM load partition ---")
|
||||
print(f" tTMEM_LOADtS shape: {cute.shape(tTMEM_LOADtS)}")
|
||||
print(f" tTMEM_LOADcS shape: {cute.shape(tTMEM_LOADcS)}")
|
||||
print(f" tTMEM_LOADcS layout: {tTMEM_LOADcS.layout}")
|
||||
print(f" => Each softmax thread has {cute.size(tTMEM_LOADcS)} S coordinates")
|
||||
|
||||
# Now try make_tiled_copy_C with qk_mma
|
||||
print(f"\n--- make_tiled_copy_C(qk_mma) attempt ---")
|
||||
try:
|
||||
smem_copy_atom = cute.make_copy_atom(
|
||||
cute.nvgpu.CopyUniversalOp(),
|
||||
cutlass.BFloat16,
|
||||
num_bits_per_copy=128,
|
||||
)
|
||||
tiled_smem_copy = cute.make_tiled_copy_C(smem_copy_atom, qk_mma)
|
||||
thr_smem_copy = tiled_smem_copy.get_slice(sfw_idx)
|
||||
sP_2d = cute.group_modes(cute.make_tensor(cute.find(cute.make_layout(p_smem_s.outer), 0), p_smem_s.outer), 0, 3)
|
||||
tSMEM_CPYsP = thr_smem_copy.partition_D(sP_2d)
|
||||
print(f" tiled_smem_copy: created OK")
|
||||
print(f" sP_2d shape: {cute.shape(sP_2d)} rank: {len(cute.shape(sP_2d))}")
|
||||
print(f" tSMEM_CPYsP shape: {cute.shape(tSMEM_CPYsP)} rank: {len(cute.shape(tSMEM_CPYsP))}")
|
||||
|
||||
# Try partition_S with the rP register layout
|
||||
# rP is in the same layout as rS (TMEM load partition)
|
||||
rP_layout = tTMEM_LOADcS.layout # coordinate layout
|
||||
rP_bf16 = cute.make_tensor(cute.find(rP_layout, 0), rP_layout)
|
||||
try:
|
||||
tSMEM_CPYrP = thr_smem_copy.partition_S(rP_bf16)
|
||||
print(f" tSMEM_CPYrP (from coord layout) shape: {cute.shape(tSMEM_CPYrP)} rank: {len(cute.shape(tSMEM_CPYrP))}")
|
||||
except Exception as e:
|
||||
print(f" partition_S with coord layout FAILED: {e}")
|
||||
|
||||
# Try with tStS layout (QK C-fragment)
|
||||
try:
|
||||
rP_qk = cute.make_tensor(cute.find(tStS.layout, 0), tStS.layout)
|
||||
tSMEM_CPYrP = thr_smem_copy.partition_S(rP_qk)
|
||||
print(f" tSMEM_CPYrP (from QK C-frag layout) shape: {cute.shape(tSMEM_CPYrP)} rank: {len(cute.shape(tSMEM_CPYrP))}")
|
||||
except Exception as e:
|
||||
print(f" partition_S with QK C-frag layout FAILED: {e}")
|
||||
|
||||
except Exception as e:
|
||||
print(f" make_tiled_copy_C FAILED: {e}")
|
||||
|
||||
# Now try make_tiled_copy (NOT make_tiled_copy_C) with pv_mma
|
||||
print(f"\n--- make_tiled_copy with PV A-operand SMEM layout ---")
|
||||
try:
|
||||
# Use the PV MMA's A-operand thread mapping
|
||||
# make_tiled_copy uses an M/N layout to determine thread partitioning
|
||||
store_atom = cute.make_copy_atom(
|
||||
cute.nvgpu.CopyUniversalOp(),
|
||||
cutlass.BFloat16,
|
||||
num_bits_per_copy=128,
|
||||
)
|
||||
# Create a tiled copy that uses the QK MMA's C-fragment layout for thread partitioning
|
||||
# but writes to the PV A-operand SMEM layout
|
||||
tiled_copy_pv = cute.make_tiled_copy_C(store_atom, pv_mma)
|
||||
print(f" make_tiled_copy_C(pv_mma) created OK")
|
||||
thr_pv = tiled_copy_pv.get_slice(sfw_idx)
|
||||
sP_2d = cute.group_modes(cute.make_tensor(cute.find(cute.make_layout(p_smem_s.outer), 0), p_smem_s.outer), 0, 3)
|
||||
tSMEM_PVsP = thr_pv.partition_D(sP_2d)
|
||||
print(f" tSMEM_PVsP shape: {cute.shape(tSMEM_PVsP)} rank: {len(cute.shape(tSMEM_PVsP))}")
|
||||
except Exception as e:
|
||||
print(f" make_tiled_copy_C(pv_mma) FAILED: {e}")
|
||||
|
||||
# Print the full layout details for manual analysis
|
||||
print(f"\n--- Detailed layout info ---")
|
||||
print(f" p_smem_s.outer type: {type(p_smem_s.outer)}")
|
||||
print(f" p_smem_s.outer: {p_smem_s.outer}")
|
||||
if hasattr(p_smem_s.outer, 'shape'):
|
||||
print(f" p_smem_s.outer.shape: {p_smem_s.outer.shape}")
|
||||
if hasattr(p_smem_s.outer, 'stride'):
|
||||
print(f" p_smem_s.outer.stride: {p_smem_s.outer.stride}")
|
||||
|
||||
# Understand the P matrix dimensions
|
||||
# P is (128, n_kv) in the attention computation
|
||||
# But PV MMA sees P as (128, pv_n_tile) A-operand
|
||||
# For hd=64: P is (128, 64) with n_kv=128 KV positions
|
||||
# Actually P is (M, K_PV) where K_PV = n_kv for each head
|
||||
# Wait - PV is P[M, n_kv] @ V[n_kv, hd]
|
||||
# So P's second dimension is n_kv (= 128 for s_k=128)
|
||||
print(f"\n--- P matrix dimensions ---")
|
||||
print(f" P is (128, {s_k}) attention weights matrix")
|
||||
print(f" PV GEMM: P[128, {s_k}] @ V[{s_k}, {pv_n_tile}] = O[128, {pv_n_tile}]")
|
||||
print(f" PV A-operand shape: (128, {s_k}) = QK MMA's K dimension")
|
||||
print(f" But PV MMA tiler K dimension: {pv_mma_tiler[2]}")
|
||||
print(f" So PV A-operand is (128, {pv_mma_tiler[2]}) per PV sub-tile")
|
||||
|
||||
# Key insight: P's dimensions are (128, n_kv) in the QK MMA's K dimension
|
||||
# The QK C-fragment has shape ((128, 128), 1, 1, 1) for S
|
||||
# The PV A-operand SMEM has shape that depends on PV MMA tiling
|
||||
# These represent the SAME data (128 rows, 128 KV positions)
|
||||
# But they're partitioned differently by the two MMAs
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
analyze_layouts(64) # TMEM-P (works)
|
||||
analyze_layouts(256) # SMEM-P (broken)
|
||||
Reference in New Issue
Block a user