Fix SFA/SFB SMEM: blockscaled layouts are plain Layout (no .outer/.inner swizzle)
This commit is contained in:
@@ -333,8 +333,8 @@ class Nvfp4FusedRouterKernel:
|
||||
heap_acts: cute.struct.Align[cute.struct.MemRange[cutlass.Float32, 4*32*6], 128]
|
||||
sA: cute.struct.Align[cute.struct.MemRange[self.a_dtype, cute.cosize(a_smem_layout_staged.outer)], self.buffer_align_bytes]
|
||||
sB: cute.struct.Align[cute.struct.MemRange[self.b_dtype, cute.cosize(b_smem_layout_staged.outer)], self.buffer_align_bytes]
|
||||
-sSFA: cute.struct.Align[cute.struct.MemRange[self.sf_dtype, cute.cosize(sfa_smem_layout_staged.outer)], self.buffer_align_bytes]
|
||||
-sSFB: cute.struct.Align[cute.struct.MemRange[self.sf_dtype, cute.cosize(sfb_smem_layout_staged.outer)], self.buffer_align_bytes]
|
||||
+sSFA: cute.struct.Align[cute.struct.MemRange[self.sf_dtype, cute.cosize(sfa_smem_layout_staged)], self.buffer_align_bytes]
|
||||
+sSFB: cute.struct.Align[cute.struct.MemRange[self.sf_dtype, cute.cosize(sfb_smem_layout_staged)], self.buffer_align_bytes]
|
||||
|
||||
smem = utils.SmemAllocator()
|
||||
storage = smem.allocate(SharedStorage)
|
||||
@@ -377,8 +377,9 @@ class Nvfp4FusedRouterKernel:
|
||||
# ==============================================================
|
||||
sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner)
|
||||
sB = storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner)
|
||||
-sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged.outer, swizzle=sfa_smem_layout_staged.inner)
|
||||
-sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged.outer, swizzle=sfb_smem_layout_staged.inner)
|
||||
+# SFA/SFB use blockscaled layouts (plain Layout, no swizzle)
|
||||
+sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged)
|
||||
+sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged)
|
||||
|
||||
# Multicast masks
|
||||
a_mcast = None; b_mcast = None; sfb_mcast = None
|
||||
|
||||
Reference in New Issue
Block a user