fix: compute num_tma_load_bytes inside cute.compile context

This commit is contained in:
2026-05-31 23:53:13 +00:00
parent 1bc0da0f35
commit 3b5b9f487c

View File

@@ -113,6 +113,12 @@ class DenseRouterDecodeKernel:
def _compiled_fn(X, W_gate, e_bias, out_w, out_ids):
self._setup_attributes()
tiled_mma = self._tiled_mma
atom_thr_size = cute.size(tiled_mma.thr_id.shape)
a_smem_0 = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
a_copy = cute.size_in_bytes(self.a_dtype, a_smem_0)
b_smem_0 = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
b_copy = cute.size_in_bytes(self.b_dtype, b_smem_0)
self.num_tma_load_bytes = (a_copy + b_copy) * atom_thr_size
# Inside cute.compile, arguments are already CuTe tensors
X_cu = X