diff --git a/dsv4/kernels/router/nvfp4_fused_router_kernel.py b/dsv4/kernels/router/nvfp4_fused_router_kernel.py index 10c737f2..e3aeaafe 100644 --- a/dsv4/kernels/router/nvfp4_fused_router_kernel.py +++ b/dsv4/kernels/router/nvfp4_fused_router_kernel.py @@ -67,13 +67,10 @@ class Nvfp4FusedRouterKernel: self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE self.arch = "sm_100" - self.mma_inst_shape_mn = ( - cutlass.Int32(mma_tiler_mnk[0]), - cutlass.Int32(mma_tiler_mnk[1]), - ) + self.mma_inst_shape_mn = (mma_tiler_mnk[0], mma_tiler_mnk[1]) self.mma_inst_shape_mn_sfb = ( - self.mma_inst_shape_mn[0] // (2 if self.use_2cta_instrs else 1), - cute.round_up(self.mma_inst_shape_mn[1], 128), + mma_tiler_mnk[0] // (2 if self.use_2cta_instrs else 1), + cute.round_up(mma_tiler_mnk[1], 128), ) # 6-warp specialization (no scheduler warp for dense GEMM) @@ -111,16 +108,8 @@ class Nvfp4FusedRouterKernel: mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2]) mma_inst_tile_k = self.mma_tiler_mnk[2] // mma_inst_shape_k - self.mma_tiler = ( - self.mma_inst_shape_mn[0], - self.mma_inst_shape_mn[1], - cutlass.Int32(self.mma_tiler_mnk[2]), - ) - self.mma_tiler_sfb = ( - self.mma_inst_shape_mn_sfb[0], - self.mma_inst_shape_mn_sfb[1], - cutlass.Int32(self.mma_tiler_mnk[2]), - ) + self.mma_tiler = (self.mma_tiler_mnk[0], self.mma_tiler_mnk[1], self.mma_tiler_mnk[2]) + self.mma_tiler_sfb = (self.mma_inst_shape_mn_sfb[0], self.mma_inst_shape_mn_sfb[1], self.mma_tiler_mnk[2]) self.cta_tile_shape_mnk = ( self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), self.mma_tiler[1], @@ -133,10 +122,10 @@ class Nvfp4FusedRouterKernel: ) self.cluster_layout_vmnk = cute.tiled_divide( - cute.make_layout((cutlass.Int32(self.cluster_shape_mn[0]), cutlass.Int32(self.cluster_shape_mn[1]), cutlass.Int32(1))), + cute.make_layout((self.cluster_shape_mn[0], self.cluster_shape_mn[1], 1)), (tiled_mma.thr_id.shape,)) self.cluster_layout_sfb_vmnk = cute.tiled_divide( - cute.make_layout((cutlass.Int32(self.cluster_shape_mn[0]), cutlass.Int32(self.cluster_shape_mn[1]), cutlass.Int32(1))), + cute.make_layout((self.cluster_shape_mn[0], self.cluster_shape_mn[1], 1)), (tiled_mma_sfb.thr_id.shape,)) self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2]) @@ -287,67 +276,65 @@ class Nvfp4FusedRouterKernel: num_N_tiles = (N + cta_n - 1) // cta_n grid = (num_M_tiles * num_N_tiles, 1, 1) - # Setup tiled MMA and attributes on HOST side (outside JIT) - # Same pattern as fused_swiglu.py __call__ - # _setup_attributes calls cute.size(tiled_mma.shape_mnk, mode=[2]) - # which requires host-side execution (not inside @cute.jit) - tiled_mma = self._create_tiled_mma(a_dtype, a_major_mode, b_major_mode, sf_dtype) - tiled_mma_sfb = self._create_tiled_mma_sfb(a_dtype, a_major_mode, b_major_mode, sf_dtype) - self._setup_attributes(tiled_mma, tiled_mma_sfb, a_dtype, b_dtype, sf_dtype, c_dtype, c_layout) + @cute.jit + def _compiled_fn(mat_a, mat_b, scale_a, scale_b, mat_c): + # Create tiled MMA and setup inside JIT context + # (same pattern as fused_swiglu.py @cute.jit __call__) + # Plain int mma_tiler values work with cute.size() inside JIT + tiled_mma = self._create_tiled_mma(a_dtype, a_major_mode, b_major_mode, sf_dtype) + tiled_mma_sfb = self._create_tiled_mma_sfb(a_dtype, a_major_mode, b_major_mode, sf_dtype) + self._setup_attributes(tiled_mma, tiled_mma_sfb, a_dtype, b_dtype, sf_dtype, c_dtype, c_layout) - # TMA atoms (host side, same as fused_swiglu) - a_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id) - a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) - tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( - a_op, mat_a, a_smem_layout, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape) + # TMA atoms (inside JIT, same as fused_swiglu) + a_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id) + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( + a_op, mat_a, a_smem_layout, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape) - b_op = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id) - b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) - tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( - b_op, mat_b, b_smem_layout, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape) + b_op = sm100_utils.cluster_shape_to_tma_atom_B(self.cluster_shape_mn, tiled_mma.thr_id) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, mat_b, b_smem_layout, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape) - sfa_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id) - sfa_smem_layout = cute.slice_(self.sfa_smem_layout_staged, (None, None, None, 0)) - tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A( - sfa_op, scale_a, sfa_smem_layout, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape, - internal_type=cutlass.Uint64) + sfa_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id) + sfa_smem_layout = cute.slice_(self.sfa_smem_layout_staged, (None, None, None, 0)) + tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A( + sfa_op, scale_a, sfa_smem_layout, self.mma_tiler, tiled_mma, self.cluster_layout_vmnk.shape, + internal_type=cutlass.Uint64) - sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(self.cluster_shape_mn, tiled_mma.thr_id) - sfb_smem_layout = cute.slice_(self.sfb_smem_layout_staged, (None, None, None, 0)) - tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B( - sfb_op, scale_b, sfb_smem_layout, self.mma_tiler_sfb, tiled_mma_sfb, - self.cluster_layout_sfb_vmnk.shape, internal_type=cutlass.Uint64) + sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB(self.cluster_shape_mn, tiled_mma.thr_id) + sfb_smem_layout = cute.slice_(self.sfb_smem_layout_staged, (None, None, None, 0)) + tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B( + sfb_op, scale_b, sfb_smem_layout, self.mma_tiler_sfb, tiled_mma_sfb, + self.cluster_layout_sfb_vmnk.shape, internal_type=cutlass.Uint64) - epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0)) - tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( - cpasync.CopyBulkTensorTileS2GOp(), mat_c, epi_smem_layout, self.epi_tile) + epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0)) + tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), mat_c, epi_smem_layout, self.epi_tile) - tile_sched_params = utils.PersistentTileSchedulerParams( - (cutlass.Int32(num_M_tiles), cutlass.Int32(num_N_tiles), cutlass.Int32(1)), - (1, 1, 1)) + tile_sched_params = utils.PersistentTileSchedulerParams( + (num_M_tiles, num_N_tiles, 1), (1, 1, 1)) - # Launch kernel directly (same as fused_swiglu pattern) - self._kernel( - tiled_mma, tiled_mma_sfb, - tma_atom_a, tma_tensor_a, tma_atom_b, tma_tensor_b, - tma_atom_sfa, tma_tensor_sfa, tma_atom_sfb, tma_tensor_sfb, - tma_atom_c, tma_tensor_c, - self.cluster_layout_vmnk, self.cluster_layout_sfb_vmnk, - self.a_smem_layout_staged, self.b_smem_layout_staged, - self.sfa_smem_layout_staged, self.sfb_smem_layout_staged, - self.c_smem_layout_staged, - self.epi_tile, - tile_sched_params, - M, N, K, gsa, gsb, - ).launch( - grid=grid, block=[self.threads_per_cta, 1, 1], - cluster=(*self.cluster_shape_mn, 1), - stream=stream, min_blocks_per_mp=1, - ) + self._kernel( + tiled_mma, tiled_mma_sfb, + tma_atom_a, tma_tensor_a, tma_atom_b, tma_tensor_b, + tma_atom_sfa, tma_tensor_sfa, tma_atom_sfb, tma_tensor_sfb, + tma_atom_c, tma_tensor_c, + self.cluster_layout_vmnk, self.cluster_layout_sfb_vmnk, + self.a_smem_layout_staged, self.b_smem_layout_staged, + self.sfa_smem_layout_staged, self.sfb_smem_layout_staged, + self.c_smem_layout_staged, + self.epi_tile, + tile_sched_params, + M, N, K, gsa, gsb, + ).launch( + grid=grid, block=[self.threads_per_cta, 1, 1], + cluster=(*self.cluster_shape_mn, 1), + stream=stream, min_blocks_per_mp=1, + ) + + cute.compile(_compiled_fn, mat_a, mat_b, scale_a, scale_b, mat_c) - # ----------------------------------------------------------------- - # GPU kernel - # ----------------------------------------------------------------- @cute.kernel def _kernel(self, tiled_mma, tiled_mma_sfb, tma_atom_a, mA_mkl, tma_atom_b, mB_nkl,