D1.3: Skip fragment creation in diagnostic, just print layouts

This commit is contained in:
2026-05-23 22:21:31 +00:00
parent b871e6874b
commit a90fe41b6b

View File

@@ -91,22 +91,16 @@ def main():
print(f" sP_2d layout: {sP_2d.layout}")
print(f" sP_2d layout: {sP_2d.layout}")
# PV A-operand fragments
tP_tmem = cute.make_tensor(0, p_smem_s.outer) # dummy ptr
tP_smem = cute.make_tensor(0, p_smem_s.outer)
tOrP_tmem = pv_thr.make_fragment_A(tP_tmem)
tCrP_smem = pv_mma.make_fragment_A(tP_smem)
tOrP_smem = pv_thr.make_fragment_A(tP_smem)
print(f" tOrP_tmem shape: {cute.shape(tOrP_tmem)}")
print(f" tCrP_smem shape: {cute.shape(tCrP_smem)} layout: {tCrP_smem.layout}")
print(f" tOrP_smem shape: {cute.shape(tOrP_smem)}")
# Slice as kernel does
tOrP_tmem_s = tOrP_tmem[(None,None,None,0)]
tOrP_smem_s = tOrP_smem[(None,None,None,0)]
print(f" tOrP_tmem sliced: {cute.shape(tOrP_tmem_s)}")
print(f" tOrP_smem sliced: {cute.shape(tOrP_smem_s)}")
# PV A-operand fragments — can't create with dummy ptr 0 (no memspace)
# Just print the layout info from the SMEM layout
print(f" p_smem_s.outer (sP layout): {p_smem_s.outer}")
print(f" p_smem_s.inner (swizzle): {p_smem_s.inner}")
# What we need: the PV MMA's A-operand thread partition for sP
# This tells us which threads read which sP elements during PV GEMM
# The softmax warps must WRITE to sP using the same mapping
# (so that the MMA warp can READ using its own partition)
# Softmax TMEM load partition
sfw_idx = 0
tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32)