Fix fused router: plain ints for mma_tiler + @cute.jit pattern

Root cause of the previous crash: wrapping the mma_inst_shape_mn entries in
cutlass.Int32 (e.g. cutlass.Int32(128)) caused _unpack_x_tuple to fail inside
cute.size(tiled_mma.shape_mnk, mode=[2]).

The fused_swiglu kernel uses plain Python ints for mma_tiler_mnk and
mma_inst_shape_mn — NOT cutlass.Int32. Inside @cute.jit, CuTeDSL
auto-converts plain ints to MLIR values. The Int32 wrapping was unnecessary
and actually harmful.

Pattern (same as fused_swiglu.py __call__):
- the @cute.jit function _compiled_fn takes CuTe tensors as arguments
- _setup_attributes is called inside the JIT function (it needs an MLIR context)
- cute.compile is invoked on _compiled_fn at the end
This commit is contained in:
2026-06-01 10:37:15 +00:00
parent 057ae2101e
commit e0f60b9f05

View File

@@ -67,13 +67,10 @@ class Nvfp4FusedRouterKernel:
self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE
self.arch = "sm_100"
self.mma_inst_shape_mn = (
cutlass.Int32(mma_tiler_mnk[0]),
cutlass.Int32(mma_tiler_mnk[1]),
)
self.mma_inst_shape_mn = (mma_tiler_mnk[0], mma_tiler_mnk[1])
self.mma_inst_shape_mn_sfb = (
self.mma_inst_shape_mn[0] // (2 if self.use_2cta_instrs else 1),
cute.round_up(self.mma_inst_shape_mn[1], 128),
mma_tiler_mnk[0] // (2 if self.use_2cta_instrs else 1),
cute.round_up(mma_tiler_mnk[1], 128),
)
# 6-warp specialization (no scheduler warp for dense GEMM)
@@ -111,16 +108,8 @@ class Nvfp4FusedRouterKernel:
mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2])
mma_inst_tile_k = self.mma_tiler_mnk[2] // mma_inst_shape_k
self.mma_tiler = (
self.mma_inst_shape_mn[0],
self.mma_inst_shape_mn[1],
cutlass.Int32(self.mma_tiler_mnk[2]),
)
self.mma_tiler_sfb = (
self.mma_inst_shape_mn_sfb[0],
self.mma_inst_shape_mn_sfb[1],
cutlass.Int32(self.mma_tiler_mnk[2]),
)
self.mma_tiler = (self.mma_tiler_mnk[0], self.mma_tiler_mnk[1], self.mma_tiler_mnk[2])
self.mma_tiler_sfb = (self.mma_inst_shape_mn_sfb[0], self.mma_inst_shape_mn_sfb[1], self.mma_tiler_mnk[2])
self.cta_tile_shape_mnk = (
self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape),
self.mma_tiler[1],
@@ -133,10 +122,10 @@ class Nvfp4FusedRouterKernel:
)
self.cluster_layout_vmnk = cute.tiled_divide(
cute.make_layout((cutlass.Int32(self.cluster_shape_mn[0]), cutlass.Int32(self.cluster_shape_mn[1]), cutlass.Int32(1))),
cute.make_layout((self.cluster_shape_mn[0], self.cluster_shape_mn[1], 1)),
(tiled_mma.thr_id.shape,))
self.cluster_layout_sfb_vmnk = cute.tiled_divide(
cute.make_layout((cutlass.Int32(self.cluster_shape_mn[0]), cutlass.Int32(self.cluster_shape_mn[1]), cutlass.Int32(1))),
cute.make_layout((self.cluster_shape_mn[0], self.cluster_shape_mn[1], 1)),
(tiled_mma_sfb.thr_id.shape,))
self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2])
@@ -287,67 +276,65 @@ class Nvfp4FusedRouterKernel:
num_N_tiles = (N + cta_n - 1) // cta_n
grid = (num_M_tiles * num_N_tiles, 1, 1)
# Setup tiled MMA and attributes on HOST side (outside JIT)
# Same pattern as fused_swiglu.py __call__
# _setup_attributes calls cute.size(tiled_mma.shape_mnk, mode=[2])
# which requires host-side execution (not inside @cute.jit)
tiled_mma = self._create_tiled_mma(a_dtype, a_major_mode, b_major_mode, sf_dtype)
tiled_mma_sfb = self._create_tiled_mma_sfb(a_dtype, a_major_mode, b_major_mode, sf_dtype)
self._setup_attributes(tiled_mma, tiled_mma_sfb, a_dtype, b_dtype, sf_dtype, c_dtype, c_layout)
@cute.jit
def _compiled_fn(mat_a, mat_b, scale_a, scale_b, mat_c):
# Create tiled MMA and setup inside JIT context
# (same pattern as fused_swiglu.py @cute.jit __call__)
# Plain int mma_tiler values work with cute.size() inside JIT
tiled_mma = self._create_tiled_mma(a_dtype, a_major_mode, b_major_mode, sf_dtype)
tiled_mma_sfb = self._create_tiled_mma_sfb(a_dtype, a_major_mode, b_major_mode, sf_dtype)
self._setup_attributes(tiled_mma, tiled_mma_sfb, a_dtype, b_dtype, sf_dtype, c_dtype, c_layout)
# TMA atoms (host side, same as fused_swiglu)
a_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id)
a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
a_op, mat_a, a_smem_layout, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape)
# TMA atoms (inside JIT, same as fused_swiglu)
a_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id)
a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
a_op, mat_a, a_smem_layout, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape)
b_op = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id)
b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
b_op, mat_b, b_smem_layout, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape)
b_op = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id)
b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
b_op, mat_b, b_smem_layout, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape)
sfa_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id)
sfa_smem_layout = cute.slice_(self.sfa_smem_layout_staged, (None, None, None, 0))
tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A(
sfa_op, scale_a, sfa_smem_layout, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape,
internal_type=cutlass.Uint64)
sfa_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id)
sfa_smem_layout = cute.slice_(self.sfa_smem_layout_staged, (None, None, None, 0))
tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A(
sfa_op, scale_a, sfa_smem_layout, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape,
internal_type=cutlass.Uint64)
sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(self.cluster_shape_mn, tiled_mma.thr_id)
sfb_smem_layout = cute.slice_(self.sfb_smem_layout_staged, (None, None, None, 0))
tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B(
sfb_op, scale_b, sfb_smem_layout, self.mma_tiler_sfb, tiled_mma_sfb,
self.cluster_layout_sfb_vmnk.shape, internal_type=cutlass.Uint64)
sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(self.cluster_shape_mn, tiled_mma.thr_id)
sfb_smem_layout = cute.slice_(self.sfb_smem_layout_staged, (None, None, None, 0))
tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B(
sfb_op, scale_b, sfb_smem_layout, self.mma_tiler_sfb, tiled_mma_sfb,
self.cluster_layout_sfb_vmnk.shape, internal_type=cutlass.Uint64)
epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0))
tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom(
cpasync.CopyBulkTensorTileS2GOp(), mat_c, epi_smem_layout, self.epi_tile)
epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0))
tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom(
cpasync.CopyBulkTensorTileS2GOp(), mat_c, epi_smem_layout, self.epi_tile)
tile_sched_params = utils.PersistentTileSchedulerParams(
(cutlass.Int32(num_M_tiles), cutlass.Int32(num_N_tiles), cutlass.Int32(1)),
(1, 1, 1))
tile_sched_params = utils.PersistentTileSchedulerParams(
(num_M_tiles, num_N_tiles, 1), (1, 1, 1))
# Launch kernel directly (same as fused_swiglu pattern)
self._kernel(
tiled_mma, tiled_mma_sfb,
tma_atom_a, tma_tensor_a, tma_atom_b, tma_tensor_b,
tma_atom_sfa, tma_tensor_sfa, tma_atom_sfb, tma_tensor_sfb,
tma_atom_c, tma_tensor_c,
self.cluster_layout_vmnk, self.cluster_layout_sfb_vmnk,
self.a_smem_layout_staged, self.b_smem_layout_staged,
self.sfa_smem_layout_staged, self.sfb_smem_layout_staged,
self.c_smem_layout_staged,
self.epi_tile,
tile_sched_params,
M, N, K, gsa, gsb,
).launch(
grid=grid, block=[self.threads_per_cta, 1, 1],
cluster=(*self.cluster_shape_mn, 1),
stream=stream, min_blocks_per_mp=1,
)
self._kernel(
tiled_mma, tiled_mma_sfb,
tma_atom_a, tma_tensor_a, tma_atom_b, tma_tensor_b,
tma_atom_sfa, tma_tensor_sfa, tma_atom_sfb, tma_tensor_sfb,
tma_atom_c, tma_tensor_c,
self.cluster_layout_vmnk, self.cluster_layout_sfb_vmnk,
self.a_smem_layout_staged, self.b_smem_layout_staged,
self.sfa_smem_layout_staged, self.sfb_smem_layout_staged,
self.c_smem_layout_staged,
self.epi_tile,
tile_sched_params,
M, N, K, gsa, gsb,
).launch(
grid=grid, block=[self.threads_per_cta, 1, 1],
cluster=(*self.cluster_shape_mn, 1),
stream=stream, min_blocks_per_mp=1,
)
cute.compile(_compiled_fn, mat_a, mat_b, scale_a, scale_b, mat_c)
# -----------------------------------------------------------------
# GPU kernel
# -----------------------------------------------------------------
@cute.kernel
def _kernel(self, tiled_mma, tiled_mma_sfb,
tma_atom_a, mA_mkl, tma_atom_b, mB_nkl,