Fix: use self.mma_tiler_mnk (full K=64) for SMEM layout computation
SFA/SFB SMEM layouts need the full K dimension to compute the correct number of K-tiles. self.mma_tiler has K=1 (placeholder for cute.slice_) which gives 0 K-tiles and zero-dimension SMEM shapes.
This commit is contained in:
@@ -154,13 +154,13 @@ class Nvfp4FusedRouterKernel:
|
||||
|
||||
# SMEM layouts
|
||||
self.a_smem_layout_staged = sm100_utils.make_smem_layout_a(
|
||||
tiled_mma, self.mma_tiler, a_dtype, self.num_ab_stage)
|
||||
tiled_mma, self.mma_tiler_mnk, a_dtype, self.num_ab_stage)
|
||||
self.b_smem_layout_staged = sm100_utils.make_smem_layout_b(
|
||||
tiled_mma, self.mma_tiler, b_dtype, self.num_ab_stage)
|
||||
tiled_mma, self.mma_tiler_mnk, b_dtype, self.num_ab_stage)
|
||||
self.sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa(
|
||||
tiled_mma, self.mma_tiler, self.sf_vec_size, self.num_ab_stage)
|
||||
tiled_mma, self.mma_tiler_mnk, self.sf_vec_size, self.num_ab_stage)
|
||||
self.sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb(
|
||||
tiled_mma, self.mma_tiler, self.sf_vec_size, self.num_ab_stage)
|
||||
tiled_mma, self.mma_tiler_mnk, self.sf_vec_size, self.num_ab_stage)
|
||||
self.c_smem_layout_staged = sm100_utils.make_smem_layout_epi(
|
||||
c_dtype, c_layout, self.epi_tile, self.num_c_stage)
|
||||
|
||||
@@ -561,7 +561,7 @@ class Nvfp4FusedRouterKernel:
|
||||
|
||||
# S2T for SFA
|
||||
tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa(
|
||||
tiled_mma, self.mma_tiler, self.sf_vec_size,
|
||||
tiled_mma, self.mma_tiler_mnk, self.sf_vec_size,
|
||||
cute.slice_(sfa_smem_layout_staged, (None, None, None, 0)))
|
||||
tCtSFA = cute.make_tensor(acc_tmem_ptr, tCtSFA_layout)
|
||||
# S2T for SFB
|
||||
|
||||
Reference in New Issue
Block a user