fix: use from_dlpack + mark_layout_dynamic instead of non-existent to_cuTe_tensor in router
This commit is contained in:
@@ -114,11 +114,11 @@ class DenseRouterDecodeKernel:
|
||||
self._setup_attributes()
|
||||
tiled_mma = self._tiled_mma
|
||||
|
||||
X_cu = cutlass_torch.to_cuTe_tensor(X, major_mode=self.a_major_mode)
|
||||
W_cu = cutlass_torch.to_cuTe_tensor(W_gate, major_mode=self.b_major_mode)
|
||||
e_bias_cu = cutlass_torch.to_cuTe_tensor(e_bias)
|
||||
out_w_cu = cutlass_torch.to_cuTe_tensor(out_w)
|
||||
out_ids_cu = cutlass_torch.to_cuTe_tensor(out_ids)
|
||||
X_cu = cutlass_torch.from_dlpack(X).mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(X))
|
||||
W_cu = cutlass_torch.from_dlpack(W_gate).mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(W_gate))
|
||||
e_bias_cu = cutlass_torch.from_dlpack(e_bias).mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(e_bias))
|
||||
out_w_cu = cutlass_torch.from_dlpack(out_w).mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(out_w))
|
||||
out_ids_cu = cutlass_torch.from_dlpack(out_ids).mark_layout_dynamic(leading_dim=cutlass_torch.get_leading_dim(out_ids))
|
||||
|
||||
a_smem = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0))
|
||||
a_op = sm100_utils.cluster_shape_to_tma_atom_A(self.cluster_shape_mn, tiled_mma.thr_id)
|
||||
|
||||
Reference in New Issue
Block a user