Fix garbled shared_expert_pipeline.py: imports/class were merged

This commit is contained in:
2026-05-19 07:18:10 +00:00
parent 87453a53b0
commit 0a7769972f

View File

@@ -22,13 +22,6 @@ import torch
from cutedsl.bridge import (
quantize_activation_nvfp4,
class _SharedExpertApply(torch.autograd.Function):
"""Custom autograd function to make CuTeDSL runner opaque to torch.compile."""
@staticmethod
def forward(ctx, runner, hidden_states):
return runner._run_impl(hidden_states)
quantize_to_nvfp4,
make_b_k_major,
assemble_scales_3d_side,
@@ -40,6 +33,13 @@ from cutedsl.kernel.moe.torch_scaled_grouped_mm import (
)
class _SharedExpertApply(torch.autograd.Function):
"""Custom autograd function to make CuTeDSL runner opaque to torch.compile."""
@staticmethod
def forward(ctx, runner, hidden_states):
return runner._run_impl(hidden_states)
class CuTeDSLSharedExpertRunner:
"""NVFP4 shared expert runner using CuTeDSL GEMM (num_groups=1).