D1.3: Fix diagnostic - use dummy ptr 0 for shape analysis
This commit is contained in:
@@ -81,16 +81,19 @@ def main():
|
||||
|
||||
# PV A-operand SMEM layout (sP)
|
||||
p_smem_s = utils.sm100.make_smem_layout_a(pv_mma, pv_mma_tiler, BFloat16, 1)
|
||||
# p_smem_s.outer IS a layout, not a shape
|
||||
sP_alloc = cute.make_tensor(cute.find(p_smem_s.outer, 0), p_smem_s.outer)
|
||||
# p_smem_s.outer is the layout. Use it directly for shape queries.
|
||||
# Can't actually allocate SMEM here (that's the SmemAllocator's job in the kernel)
|
||||
# But we can create a tensor with a dummy pointer for shape analysis.
|
||||
sP_alloc = cute.make_tensor(0, p_smem_s.outer)
|
||||
print(f" sP shape: {cute.shape(sP_alloc)}")
|
||||
sP_2d = cute.group_modes(sP_alloc, 0, 3)
|
||||
print(f" sP_2d shape: {cute.shape(sP_2d)} rank={len(cute.shape(sP_2d))}")
|
||||
print(f" sP_2d layout: {sP_2d.layout}")
|
||||
print(f" sP_2d layout: {sP_2d.layout}")
|
||||
|
||||
# PV A-operand fragments
|
||||
tP_tmem = cute.make_tensor(tStS.iterator, p_smem_s.outer) # dummy tmem ptr
|
||||
tP_smem = cute.make_tensor(cute.find(p_smem_s.outer, 0), p_smem_s.outer)
|
||||
tP_tmem = cute.make_tensor(0, p_smem_s.outer) # dummy ptr
|
||||
tP_smem = cute.make_tensor(0, p_smem_s.outer)
|
||||
tOrP_tmem = pv_thr.make_fragment_A(tP_tmem)
|
||||
tCrP_smem = pv_mma.make_fragment_A(tP_smem)
|
||||
tOrP_smem = pv_thr.make_fragment_A(tP_smem)
|
||||
|
||||
Reference in New Issue
Block a user