diff --git a/cutedsl/shared_expert_pipeline.py b/cutedsl/shared_expert_pipeline.py
index 7d1429bd..ce9aa4be 100644
--- a/cutedsl/shared_expert_pipeline.py
+++ b/cutedsl/shared_expert_pipeline.py
@@ -22,13 +22,6 @@ import torch
 
 from cutedsl.bridge import (
     quantize_activation_nvfp4,
-
-
-class _SharedExpertApply(torch.autograd.Function):
-    """Custom autograd function to make CuTeDSL runner opaque to torch.compile."""
-    @staticmethod
-    def forward(ctx, runner, hidden_states):
-        return runner._run_impl(hidden_states)
     quantize_to_nvfp4,
     make_b_k_major,
     assemble_scales_3d_side,
@@ -40,6 +33,13 @@ from cutedsl.kernel.moe.torch_scaled_grouped_mm import (
 )
 
 
+class _SharedExpertApply(torch.autograd.Function):
+    """Custom autograd function to make CuTeDSL runner opaque to torch.compile."""
+    @staticmethod
+    def forward(ctx, runner, hidden_states):
+        return runner._run_impl(hidden_states)
+
+
 class CuTeDSLSharedExpertRunner:
     """NVFP4 shared expert runner using CuTeDSL GEMM (num_groups=1).
 