Fix: mma_tiler must use CuTe Ints for static layout construction
This commit is contained in:
@@ -220,17 +220,13 @@ class Nvfp4FusedRouterKernel:
|
||||
self.a_major_mode = utils.LayoutEnum.from_tensor(mat_a).mma_major_mode()
|
||||
self.b_major_mode = utils.LayoutEnum.from_tensor(mat_b).mma_major_mode()
|
||||
|
||||
mma_inst_shape_k = 32
|
||||
mma_inst_tile_k = 4
|
||||
self.mma_tiler = (*self.mma_tiler_mn, mma_inst_shape_k * mma_inst_tile_k)
|
||||
# Set mma_tiler with CuTe Ints so that blockscaled layout construction
|
||||
# produces static (not dynamic) dimensions. K=128 FP4 elements per K-tile.
|
||||
self.mma_tiler = (cutlass.Int32(self.mma_tiler_mn[0]), cutlass.Int32(self.mma_tiler_mn[1]), cutlass.Int32(128))
|
||||
|
||||
self._setup_attributes()
|
||||
tiled_mma = self._tiled_mma
|
||||
tiled_mma_sfb = self._tiled_mma_sfb
|
||||
|
||||
# Ensure mma_tiler contains CuTe Ints (not Python ints)
|
||||
self.mma_tiler = (cutlass.Int32(self.mma_tiler[0]), cutlass.Int32(self.mma_tiler[1]), cutlass.Int32(self.mma_tiler[2]))
|
||||
self.mma_tiler_sfb = (cutlass.Int32(self.mma_tiler_sfb[0]), cutlass.Int32(self.mma_tiler_sfb[1]), cutlass.Int32(self.mma_tiler_sfb[2]))
|
||||
atom_thr_size = cute.size(tiled_mma.thr_id.shape)
|
||||
|
||||
a_smem_0 = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
|
||||
|
||||
Reference in New Issue
Block a user