From ec8f292112be6b45af4c83f6b79beba0d3f3ccf3 Mon Sep 17 00:00:00 2001 From: biondizzle Date: Mon, 1 Jun 2026 11:03:08 +0000 Subject: [PATCH] Fix: use self.mma_tiler_mnk (full K=64) for SMEM layout computation SFA/SFB SMEM layouts need the full K dimension to compute the correct number of K-tiles. self.mma_tiler has K=1 (placeholder for cute.slice_) which gives 0 K-tiles and zero-dimension SMEM shapes. --- dsv4/kernels/router/nvfp4_fused_router_kernel.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/dsv4/kernels/router/nvfp4_fused_router_kernel.py b/dsv4/kernels/router/nvfp4_fused_router_kernel.py index d59caeb7..1635c0b5 100644 --- a/dsv4/kernels/router/nvfp4_fused_router_kernel.py +++ b/dsv4/kernels/router/nvfp4_fused_router_kernel.py @@ -154,13 +154,13 @@ class Nvfp4FusedRouterKernel: # SMEM layouts self.a_smem_layout_staged = sm100_utils.make_smem_layout_a( - tiled_mma, self.mma_tiler, a_dtype, self.num_ab_stage) + tiled_mma, self.mma_tiler_mnk, a_dtype, self.num_ab_stage) self.b_smem_layout_staged = sm100_utils.make_smem_layout_b( - tiled_mma, self.mma_tiler, b_dtype, self.num_ab_stage) + tiled_mma, self.mma_tiler_mnk, b_dtype, self.num_ab_stage) self.sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa( - tiled_mma, self.mma_tiler, self.sf_vec_size, self.num_ab_stage) + tiled_mma, self.mma_tiler_mnk, self.sf_vec_size, self.num_ab_stage) self.sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb( - tiled_mma, self.mma_tiler, self.sf_vec_size, self.num_ab_stage) + tiled_mma, self.mma_tiler_mnk, self.sf_vec_size, self.num_ab_stage) self.c_smem_layout_staged = sm100_utils.make_smem_layout_epi( c_dtype, c_layout, self.epi_tile, self.num_c_stage) @@ -561,7 +561,7 @@ class Nvfp4FusedRouterKernel: # S2T for SFA tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa( - tiled_mma, self.mma_tiler, self.sf_vec_size, + tiled_mma, self.mma_tiler_mnk, self.sf_vec_size, cute.slice_(sfa_smem_layout_staged, (None, None, None, 0))) tCtSFA = cute.make_tensor(acc_tmem_ptr, tCtSFA_layout) # S2T for SFB