Fix fused router: plain ints for mma_tiler + @cute.jit pattern
Root cause of the previous crash: wrapping mma_inst_shape_mn in cutlass.Int32 (e.g. cutlass.Int32(128)) caused _unpack_x_tuple to fail in cute.size(tiled_mma.shape_mnk, mode=[2]).

The fused_swiglu kernel uses plain Python ints for mma_tiler_mnk and mma_inst_shape_mn, NOT cutlass.Int32. Inside @cute.jit, CuTeDSL auto-converts plain ints to MLIR values, so the Int32 wrapping was unnecessary and actually harmful.

Pattern (same as fused_swiglu.py __call__):
- The @cute.jit compiled_fn takes CuTe tensors.
- _setup_attributes is called inside the JIT (it needs an MLIR context).
- cute.compile is called at the end.
This commit is contained in:
@@ -67,13 +67,10 @@ class Nvfp4FusedRouterKernel:
|
||||
self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE
|
||||
self.arch = "sm_100"
|
||||
|
||||
self.mma_inst_shape_mn = (
|
||||
cutlass.Int32(mma_tiler_mnk[0]),
|
||||
cutlass.Int32(mma_tiler_mnk[1]),
|
||||
)
|
||||
self.mma_inst_shape_mn = (mma_tiler_mnk[0], mma_tiler_mnk[1])
|
||||
self.mma_inst_shape_mn_sfb = (
|
||||
self.mma_inst_shape_mn[0] // (2 if self.use_2cta_instrs else 1),
|
||||
cute.round_up(self.mma_inst_shape_mn[1], 128),
|
||||
mma_tiler_mnk[0] // (2 if self.use_2cta_instrs else 1),
|
||||
cute.round_up(mma_tiler_mnk[1], 128),
|
||||
)
|
||||
|
||||
# 6-warp specialization (no scheduler warp for dense GEMM)
|
||||
@@ -111,16 +108,8 @@ class Nvfp4FusedRouterKernel:
|
||||
mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2])
|
||||
mma_inst_tile_k = self.mma_tiler_mnk[2] // mma_inst_shape_k
|
||||
|
||||
self.mma_tiler = (
|
||||
self.mma_inst_shape_mn[0],
|
||||
self.mma_inst_shape_mn[1],
|
||||
cutlass.Int32(self.mma_tiler_mnk[2]),
|
||||
)
|
||||
self.mma_tiler_sfb = (
|
||||
self.mma_inst_shape_mn_sfb[0],
|
||||
self.mma_inst_shape_mn_sfb[1],
|
||||
cutlass.Int32(self.mma_tiler_mnk[2]),
|
||||
)
|
||||
self.mma_tiler = (self.mma_tiler_mnk[0], self.mma_tiler_mnk[1], self.mma_tiler_mnk[2])
|
||||
self.mma_tiler_sfb = (self.mma_inst_shape_mn_sfb[0], self.mma_inst_shape_mn_sfb[1], self.mma_tiler_mnk[2])
|
||||
self.cta_tile_shape_mnk = (
|
||||
self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape),
|
||||
self.mma_tiler[1],
|
||||
@@ -133,10 +122,10 @@ class Nvfp4FusedRouterKernel:
|
||||
)
|
||||
|
||||
self.cluster_layout_vmnk = cute.tiled_divide(
|
||||
cute.make_layout((cutlass.Int32(self.cluster_shape_mn[0]), cutlass.Int32(self.cluster_shape_mn[1]), cutlass.Int32(1))),
|
||||
cute.make_layout((self.cluster_shape_mn[0], self.cluster_shape_mn[1], 1)),
|
||||
(tiled_mma.thr_id.shape,))
|
||||
self.cluster_layout_sfb_vmnk = cute.tiled_divide(
|
||||
cute.make_layout((cutlass.Int32(self.cluster_shape_mn[0]), cutlass.Int32(self.cluster_shape_mn[1]), cutlass.Int32(1))),
|
||||
cute.make_layout((self.cluster_shape_mn[0], self.cluster_shape_mn[1], 1)),
|
||||
(tiled_mma_sfb.thr_id.shape,))
|
||||
|
||||
self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2])
|
||||
@@ -287,67 +276,65 @@ class Nvfp4FusedRouterKernel:
|
||||
num_N_tiles = (N + cta_n - 1) // cta_n
|
||||
grid = (num_M_tiles * num_N_tiles, 1, 1)
|
||||
|
||||
# Setup tiled MMA and attributes on HOST side (outside JIT)
|
||||
# Same pattern as fused_swiglu.py __call__
|
||||
# _setup_attributes calls cute.size(tiled_mma.shape_mnk, mode=[2])
|
||||
# which requires host-side execution (not inside @cute.jit)
|
||||
tiled_mma = self._create_tiled_mma(a_dtype, a_major_mode, b_major_mode, sf_dtype)
|
||||
tiled_mma_sfb = self._create_tiled_mma_sfb(a_dtype, a_major_mode, b_major_mode, sf_dtype)
|
||||
self._setup_attributes(tiled_mma, tiled_mma_sfb, a_dtype, b_dtype, sf_dtype, c_dtype, c_layout)
|
||||
@cute.jit
|
||||
def _compiled_fn(mat_a, mat_b, scale_a, scale_b, mat_c):
|
||||
# Create tiled MMA and setup inside JIT context
|
||||
# (same pattern as fused_swiglu.py @cute.jit __call__)
|
||||
# Plain int mma_tiler values work with cute.size() inside JIT
|
||||
tiled_mma = self._create_tiled_mma(a_dtype, a_major_mode, b_major_mode, sf_dtype)
|
||||
tiled_mma_sfb = self._create_tiled_mma_sfb(a_dtype, a_major_mode, b_major_mode, sf_dtype)
|
||||
self._setup_attributes(tiled_mma, tiled_mma_sfb, a_dtype, b_dtype, sf_dtype, c_dtype, c_layout)
|
||||
|
||||
# TMA atoms (host side, same as fused_swiglu)
|
||||
a_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id)
|
||||
a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
|
||||
tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
a_op, mat_a, a_smem_layout, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape)
|
||||
# TMA atoms (inside JIT, same as fused_swiglu)
|
||||
a_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id)
|
||||
a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
|
||||
tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
a_op, mat_a, a_smem_layout, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape)
|
||||
|
||||
b_op = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id)
|
||||
b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
|
||||
tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
b_op, mat_b, b_smem_layout, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape)
|
||||
b_op = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id)
|
||||
b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
|
||||
tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
b_op, mat_b, b_smem_layout, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape)
|
||||
|
||||
sfa_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id)
|
||||
sfa_smem_layout = cute.slice_(self.sfa_smem_layout_staged, (None, None, None, 0))
|
||||
tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
sfa_op, scale_a, sfa_smem_layout, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape,
|
||||
internal_type=cutlass.Uint64)
|
||||
sfa_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id)
|
||||
sfa_smem_layout = cute.slice_(self.sfa_smem_layout_staged, (None, None, None, 0))
|
||||
tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A(
|
||||
sfa_op, scale_a, sfa_smem_layout, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape,
|
||||
internal_type=cutlass.Uint64)
|
||||
|
||||
sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(self.cluster_shape_mn, tiled_mma.thr_id)
|
||||
sfb_smem_layout = cute.slice_(self.sfb_smem_layout_staged, (None, None, None, 0))
|
||||
tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
sfb_op, scale_b, sfb_smem_layout, self.mma_tiler_sfb, tiled_mma_sfb,
|
||||
self.cluster_layout_sfb_vmnk.shape, internal_type=cutlass.Uint64)
|
||||
sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(self.cluster_shape_mn, tiled_mma.thr_id)
|
||||
sfb_smem_layout = cute.slice_(self.sfb_smem_layout_staged, (None, None, None, 0))
|
||||
tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B(
|
||||
sfb_op, scale_b, sfb_smem_layout, self.mma_tiler_sfb, tiled_mma_sfb,
|
||||
self.cluster_layout_sfb_vmnk.shape, internal_type=cutlass.Uint64)
|
||||
|
||||
epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0))
|
||||
tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom(
|
||||
cpasync.CopyBulkTensorTileS2GOp(), mat_c, epi_smem_layout, self.epi_tile)
|
||||
epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0))
|
||||
tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom(
|
||||
cpasync.CopyBulkTensorTileS2GOp(), mat_c, epi_smem_layout, self.epi_tile)
|
||||
|
||||
tile_sched_params = utils.PersistentTileSchedulerParams(
|
||||
(cutlass.Int32(num_M_tiles), cutlass.Int32(num_N_tiles), cutlass.Int32(1)),
|
||||
(1, 1, 1))
|
||||
tile_sched_params = utils.PersistentTileSchedulerParams(
|
||||
(num_M_tiles, num_N_tiles, 1), (1, 1, 1))
|
||||
|
||||
# Launch kernel directly (same as fused_swiglu pattern)
|
||||
self._kernel(
|
||||
tiled_mma, tiled_mma_sfb,
|
||||
tma_atom_a, tma_tensor_a, tma_atom_b, tma_tensor_b,
|
||||
tma_atom_sfa, tma_tensor_sfa, tma_atom_sfb, tma_tensor_sfb,
|
||||
tma_atom_c, tma_tensor_c,
|
||||
self.cluster_layout_vmnk, self.cluster_layout_sfb_vmnk,
|
||||
self.a_smem_layout_staged, self.b_smem_layout_staged,
|
||||
self.sfa_smem_layout_staged, self.sfb_smem_layout_staged,
|
||||
self.c_smem_layout_staged,
|
||||
self.epi_tile,
|
||||
tile_sched_params,
|
||||
M, N, K, gsa, gsb,
|
||||
).launch(
|
||||
grid=grid, block=[self.threads_per_cta, 1, 1],
|
||||
cluster=(*self.cluster_shape_mn, 1),
|
||||
stream=stream, min_blocks_per_mp=1,
|
||||
)
|
||||
self._kernel(
|
||||
tiled_mma, tiled_mma_sfb,
|
||||
tma_atom_a, tma_tensor_a, tma_atom_b, tma_tensor_b,
|
||||
tma_atom_sfa, tma_tensor_sfa, tma_atom_sfb, tma_tensor_sfb,
|
||||
tma_atom_c, tma_tensor_c,
|
||||
self.cluster_layout_vmnk, self.cluster_layout_sfb_vmnk,
|
||||
self.a_smem_layout_staged, self.b_smem_layout_staged,
|
||||
self.sfa_smem_layout_staged, self.sfb_smem_layout_staged,
|
||||
self.c_smem_layout_staged,
|
||||
self.epi_tile,
|
||||
tile_sched_params,
|
||||
M, N, K, gsa, gsb,
|
||||
).launch(
|
||||
grid=grid, block=[self.threads_per_cta, 1, 1],
|
||||
cluster=(*self.cluster_shape_mn, 1),
|
||||
stream=stream, min_blocks_per_mp=1,
|
||||
)
|
||||
|
||||
cute.compile(_compiled_fn, mat_a, mat_b, scale_a, scale_b, mat_c)
|
||||
|
||||
# -----------------------------------------------------------------
|
||||
# GPU kernel
|
||||
# -----------------------------------------------------------------
|
||||
@cute.kernel
|
||||
def _kernel(self, tiled_mma, tiled_mma_sfb,
|
||||
tma_atom_a, mA_mkl, tma_atom_b, mB_nkl,
|
||||
|
||||
Reference in New Issue
Block a user