diff --git a/dsv4/kernels/router/dense_router_decode_kernel.py b/dsv4/kernels/router/dense_router_decode_kernel.py index 28e6ee37..43e9d651 100644 --- a/dsv4/kernels/router/dense_router_decode_kernel.py +++ b/dsv4/kernels/router/dense_router_decode_kernel.py @@ -113,6 +113,12 @@ class DenseRouterDecodeKernel: def _compiled_fn(X, W_gate, e_bias, out_w, out_ids): self._setup_attributes() tiled_mma = self._tiled_mma + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + a_smem_0 = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + a_copy = cute.size_in_bytes(self.a_dtype, a_smem_0) + b_smem_0 = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + b_copy = cute.size_in_bytes(self.b_dtype, b_smem_0) + self.num_tma_load_bytes = (a_copy + b_copy) * atom_thr_size # Inside cute.compile, arguments are already CuTe tensors X_cu = X