Fix SFA/SFB SMEM: blockscaled layouts are plain Layout (no .outer/.inner swizzle)

This commit is contained in:
2026-06-01 07:14:45 +00:00
parent fcd7680583
commit 57cc20d5ad

View File

@@ -333,8 +333,8 @@ class Nvfp4FusedRouterKernel:
heap_acts: cute.struct.Align[cute.struct.MemRange[cutlass.Float32, 4*32*6], 128]
sA: cute.struct.Align[cute.struct.MemRange[self.a_dtype, cute.cosize(a_smem_layout_staged.outer)], self.buffer_align_bytes]
sB: cute.struct.Align[cute.struct.MemRange[self.b_dtype, cute.cosize(b_smem_layout_staged.outer)], self.buffer_align_bytes]
sSFA: cute.struct.Align[cute.struct.MemRange[self.sf_dtype, cute.cosize(sfa_smem_layout_staged.outer)], self.buffer_align_bytes]
sSFB: cute.struct.Align[cute.struct.MemRange[self.sf_dtype, cute.cosize(sfb_smem_layout_staged.outer)], self.buffer_align_bytes]
sSFA: cute.struct.Align[cute.struct.MemRange[self.sf_dtype, cute.cosize(sfa_smem_layout_staged)], self.buffer_align_bytes]
sSFB: cute.struct.Align[cute.struct.MemRange[self.sf_dtype, cute.cosize(sfb_smem_layout_staged)], self.buffer_align_bytes]
smem = utils.SmemAllocator()
storage = smem.allocate(SharedStorage)
@@ -377,8 +377,9 @@ class Nvfp4FusedRouterKernel:
# ==============================================================
sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner)
sB = storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner)
sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged.outer, swizzle=sfa_smem_layout_staged.inner)
sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged.outer, swizzle=sfb_smem_layout_staged.inner)
# SFA/SFB use blockscaled layouts (plain Layout, no swizzle)
sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged)
sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged)
# Multicast masks
a_mcast = None; b_mcast = None; sfb_mcast = None