diff --git a/dsv4/kernels/router/nvfp4_fused_router_kernel.py b/dsv4/kernels/router/nvfp4_fused_router_kernel.py index d59caeb7..1635c0b5 100644 --- a/dsv4/kernels/router/nvfp4_fused_router_kernel.py +++ b/dsv4/kernels/router/nvfp4_fused_router_kernel.py @@ -154,13 +154,13 @@ class Nvfp4FusedRouterKernel: # SMEM layouts self.a_smem_layout_staged = sm100_utils.make_smem_layout_a( - tiled_mma, self.mma_tiler, a_dtype, self.num_ab_stage) + tiled_mma, self.mma_tiler_mnk, a_dtype, self.num_ab_stage) self.b_smem_layout_staged = sm100_utils.make_smem_layout_b( - tiled_mma, self.mma_tiler, b_dtype, self.num_ab_stage) + tiled_mma, self.mma_tiler_mnk, b_dtype, self.num_ab_stage) self.sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa( - tiled_mma, self.mma_tiler, self.sf_vec_size, self.num_ab_stage) + tiled_mma, self.mma_tiler_mnk, self.sf_vec_size, self.num_ab_stage) self.sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb( - tiled_mma, self.mma_tiler, self.sf_vec_size, self.num_ab_stage) + tiled_mma, self.mma_tiler_mnk, self.sf_vec_size, self.num_ab_stage) self.c_smem_layout_staged = sm100_utils.make_smem_layout_epi( c_dtype, c_layout, self.epi_tile, self.num_c_stage) @@ -561,7 +561,7 @@ class Nvfp4FusedRouterKernel: # S2T for SFA tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa( - tiled_mma, self.mma_tiler, self.sf_vec_size, + tiled_mma, self.mma_tiler_mnk, self.sf_vec_size, cute.slice_(sfa_smem_layout_staged, (None, None, None, 0))) tCtSFA = cute.make_tensor(acc_tmem_ptr, tCtSFA_layout) # S2T for SFB