From 0a7769972fb0baaafa08950edda4abcfb36524dc Mon Sep 17 00:00:00 2001 From: biondizzle Date: Tue, 19 May 2026 07:18:10 +0000 Subject: [PATCH] Fix garbled shared_expert_pipeline.py: imports/class were merged --- cutedsl/shared_expert_pipeline.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/cutedsl/shared_expert_pipeline.py b/cutedsl/shared_expert_pipeline.py index 7d1429bd..ce9aa4be 100644 --- a/cutedsl/shared_expert_pipeline.py +++ b/cutedsl/shared_expert_pipeline.py @@ -22,13 +22,6 @@ import torch from cutedsl.bridge import ( quantize_activation_nvfp4, - - -class _SharedExpertApply(torch.autograd.Function): - """Custom autograd function to make CuTeDSL runner opaque to torch.compile.""" - @staticmethod - def forward(ctx, runner, hidden_states): - return runner._run_impl(hidden_states) quantize_to_nvfp4, make_b_k_major, assemble_scales_3d_side, @@ -40,6 +33,13 @@ from cutedsl.kernel.moe.torch_scaled_grouped_mm import ( ) +class _SharedExpertApply(torch.autograd.Function): + """Custom autograd function to make CuTeDSL runner opaque to torch.compile.""" + @staticmethod + def forward(ctx, runner, hidden_states): + return runner._run_impl(hidden_states) + + class CuTeDSLSharedExpertRunner: """NVFP4 shared expert runner using CuTeDSL GEMM (num_groups=1).