diff --git a/dsv4/kernels/router/nvfp4_fused_router_kernel.py b/dsv4/kernels/router/nvfp4_fused_router_kernel.py index b2aec19d..4efc965e 100644 --- a/dsv4/kernels/router/nvfp4_fused_router_kernel.py +++ b/dsv4/kernels/router/nvfp4_fused_router_kernel.py @@ -333,8 +333,8 @@ class Nvfp4FusedRouterKernel: heap_acts: cute.struct.Align[cute.struct.MemRange[cutlass.Float32, 4*32*6], 128] sA: cute.struct.Align[cute.struct.MemRange[self.a_dtype, cute.cosize(a_smem_layout_staged.outer)], self.buffer_align_bytes] sB: cute.struct.Align[cute.struct.MemRange[self.b_dtype, cute.cosize(b_smem_layout_staged.outer)], self.buffer_align_bytes] - sSFA: cute.struct.Align[cute.struct.MemRange[self.sf_dtype, cute.cosize(sfa_smem_layout_staged.outer)], self.buffer_align_bytes] - sSFB: cute.struct.Align[cute.struct.MemRange[self.sf_dtype, cute.cosize(sfb_smem_layout_staged.outer)], self.buffer_align_bytes] + sSFA: cute.struct.Align[cute.struct.MemRange[self.sf_dtype, cute.cosize(sfa_smem_layout_staged)], self.buffer_align_bytes] + sSFB: cute.struct.Align[cute.struct.MemRange[self.sf_dtype, cute.cosize(sfb_smem_layout_staged)], self.buffer_align_bytes] smem = utils.SmemAllocator() storage = smem.allocate(SharedStorage) @@ -377,8 +377,9 @@ class Nvfp4FusedRouterKernel: # ============================================================== sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner) sB = storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner) - sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged.outer, swizzle=sfa_smem_layout_staged.inner) - sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged.outer, swizzle=sfb_smem_layout_staged.inner) + # SFA/SFB use blockscaled layouts (plain Layout, no swizzle) + sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged) + sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged) # Multicast masks a_mcast = None; b_mcast = None; sfb_mcast = None