diff --git a/dsv4/kernels/router/dense_router_decode_kernel.py b/dsv4/kernels/router/dense_router_decode_kernel.py index c44bfc97..a68e3e32 100644 --- a/dsv4/kernels/router/dense_router_decode_kernel.py +++ b/dsv4/kernels/router/dense_router_decode_kernel.py @@ -114,11 +114,11 @@ class DenseRouterDecodeKernel: self._setup_attributes() tiled_mma = self._tiled_mma - X_cu = cutlass_torch.to_cuTe_tensor(X, major_mode=self.a_major_mode) - W_cu = cutlass_torch.to_cuTe_tensor(W_gate, major_mode=self.b_major_mode) - e_bias_cu = cutlass_torch.to_cuTe_tensor(e_bias) - out_w_cu = cutlass_torch.to_cuTe_tensor(out_w) - out_ids_cu = cutlass_torch.to_cuTe_tensor(out_ids) + X_cu = cutlass_torch.from_dlpack(X).mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(X)) + W_cu = cutlass_torch.from_dlpack(W_gate).mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(W_gate)) + e_bias_cu = cutlass_torch.from_dlpack(e_bias).mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(e_bias)) + out_w_cu = cutlass_torch.from_dlpack(out_w).mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(out_w)) + out_ids_cu = cutlass_torch.from_dlpack(out_ids).mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(out_ids)) a_smem = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) a_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id)