Fix: use self.mma_tiler_mnk (full K=64) for SMEM layout computation

SFA/SFB SMEM layouts need the full K dimension to compute the correct
number of K-tiles. self.mma_tiler has K=1 (placeholder for cute.slice_)
which gives 0 K-tiles and zero-dimension SMEM shapes.
This commit is contained in:
2026-06-01 11:03:08 +00:00
parent 44fb9b6c00
commit ec8f292112

View File

@@ -154,13 +154,13 @@ class Nvfp4FusedRouterKernel:
# SMEM layouts
self.a_smem_layout_staged = sm100_utils.make_smem_layout_a(
tiled_mma, self.mma_tiler, a_dtype, self.num_ab_stage)
tiled_mma, self.mma_tiler_mnk, a_dtype, self.num_ab_stage)
self.b_smem_layout_staged = sm100_utils.make_smem_layout_b(
tiled_mma, self.mma_tiler, b_dtype, self.num_ab_stage)
tiled_mma, self.mma_tiler_mnk, b_dtype, self.num_ab_stage)
self.sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa(
tiled_mma, self.mma_tiler, self.sf_vec_size, self.num_ab_stage)
tiled_mma, self.mma_tiler_mnk, self.sf_vec_size, self.num_ab_stage)
self.sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb(
tiled_mma, self.mma_tiler, self.sf_vec_size, self.num_ab_stage)
tiled_mma, self.mma_tiler_mnk, self.sf_vec_size, self.num_ab_stage)
self.c_smem_layout_staged = sm100_utils.make_smem_layout_epi(
c_dtype, c_layout, self.epi_tile, self.num_c_stage)
@@ -561,7 +561,7 @@ class Nvfp4FusedRouterKernel:
# S2T for SFA
tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa(
tiled_mma, self.mma_tiler, self.sf_vec_size,
tiled_mma, self.mma_tiler_mnk, self.sf_vec_size,
cute.slice_(sfa_smem_layout_staged, (None, None, None, 0)))
tCtSFA = cute.make_tensor(acc_tmem_ptr, tCtSFA_layout)
# S2T for SFB