fix: compute num_tma_load_bytes inside cute.compile context
This commit is contained in:
@@ -113,6 +113,12 @@ class DenseRouterDecodeKernel:
|
||||
def _compiled_fn(X, W_gate, e_bias, out_w, out_ids):
|
||||
self._setup_attributes()
|
||||
tiled_mma = self._tiled_mma
|
||||
atom_thr_size = cute.size(tiled_mma.thr_id.shape)
|
||||
a_smem_0 = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
|
||||
a_copy = cute.size_in_bytes(self.a_dtype, a_smem_0)
|
||||
b_smem_0 = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0))
|
||||
b_copy = cute.size_in_bytes(self.b_dtype, b_smem_0)
|
||||
self.num_tma_load_bytes = (a_copy + b_copy) * atom_thr_size
|
||||
|
||||
# Inside cute.compile, arguments are already CuTe tensors
|
||||
X_cu = X
|
||||
|
||||
Reference in New Issue
Block a user