fix: use from_dlpack + mark_layout_dynamic instead of non-existent to_cuTe_tensor in router

This commit is contained in:
2026-05-31 23:46:35 +00:00
parent cb2ca8591f
commit 6375e54396

View File

@@ -114,11 +114,11 @@ class DenseRouterDecodeKernel:
self._setup_attributes()
tiled_mma = self._tiled_mma
X_cu = cutlass_torch.to_cuTe_tensor(X, major_mode=self.a_major_mode)
W_cu = cutlass_torch.to_cuTe_tensor(W_gate, major_mode=self.b_major_mode)
e_bias_cu = cutlass_torch.to_cuTe_tensor(e_bias)
out_w_cu = cutlass_torch.to_cuTe_tensor(out_w)
out_ids_cu = cutlass_torch.to_cuTe_tensor(out_ids)
X_cu = cutlass_torch.from_dlpack(X).mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(X))
W_cu = cutlass_torch.from_dlpack(W_gate).mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(W_gate))
e_bias_cu = cutlass_torch.from_dlpack(e_bias).mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(e_bias))
out_w_cu = cutlass_torch.from_dlpack(out_w).mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(out_w))
out_ids_cu = cutlass_torch.from_dlpack(out_ids).mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(out_ids))
a_smem = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
a_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id)