D1.3: Skip fragment creation in diagnostic, just print layouts
This commit is contained in:
@@ -91,22 +91,16 @@ def main():
|
||||
print(f" sP_2d layout: {sP_2d.layout}")
|
||||
print(f" sP_2d layout: {sP_2d.layout}")
|
||||
|
||||
# PV A-operand fragments
|
||||
tP_tmem = cute.make_tensor(0, p_smem_s.outer) # dummy ptr
|
||||
tP_smem = cute.make_tensor(0, p_smem_s.outer)
|
||||
tOrP_tmem = pv_thr.make_fragment_A(tP_tmem)
|
||||
tCrP_smem = pv_mma.make_fragment_A(tP_smem)
|
||||
tOrP_smem = pv_thr.make_fragment_A(tP_smem)
|
||||
print(f" tOrP_tmem shape: {cute.shape(tOrP_tmem)}")
|
||||
print(f" tCrP_smem shape: {cute.shape(tCrP_smem)} layout: {tCrP_smem.layout}")
|
||||
print(f" tOrP_smem shape: {cute.shape(tOrP_smem)}")
|
||||
|
||||
# Slice as kernel does
|
||||
tOrP_tmem_s = tOrP_tmem[(None,None,None,0)]
|
||||
tOrP_smem_s = tOrP_smem[(None,None,None,0)]
|
||||
print(f" tOrP_tmem sliced: {cute.shape(tOrP_tmem_s)}")
|
||||
print(f" tOrP_smem sliced: {cute.shape(tOrP_smem_s)}")
|
||||
|
||||
# PV A-operand fragments — can't create with dummy ptr 0 (no memspace)
|
||||
# Just print the layout info from the SMEM layout
|
||||
print(f" p_smem_s.outer (sP layout): {p_smem_s.outer}")
|
||||
print(f" p_smem_s.inner (swizzle): {p_smem_s.inner}")
|
||||
|
||||
# What we need: the PV MMA's A-operand thread partition for sP
|
||||
# This tells us which threads read which sP elements during PV GEMM
|
||||
# The softmax warps must WRITE to sP using the same mapping
|
||||
# (so that the MMA warp can READ using its own partition)
|
||||
|
||||
# Softmax TMEM load partition
|
||||
sfw_idx = 0
|
||||
tmem_load_atom = cute.make_copy_atom(tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), Float32)
|
||||
|
||||
Reference in New Issue
Block a user