Fix garbled shared_expert_pipeline.py: imports/class were merged
This commit is contained in:
@@ -22,13 +22,6 @@ import torch
|
||||
|
||||
from cutedsl.bridge import (
|
||||
quantize_activation_nvfp4,
|
||||
|
||||
|
||||
class _SharedExpertApply(torch.autograd.Function):
|
||||
"""Custom autograd function to make CuTeDSL runner opaque to torch.compile."""
|
||||
@staticmethod
|
||||
def forward(ctx, runner, hidden_states):
|
||||
return runner._run_impl(hidden_states)
|
||||
quantize_to_nvfp4,
|
||||
make_b_k_major,
|
||||
assemble_scales_3d_side,
|
||||
@@ -40,6 +33,13 @@ from cutedsl.kernel.moe.torch_scaled_grouped_mm import (
|
||||
)
|
||||
|
||||
|
||||
class _SharedExpertApply(torch.autograd.Function):
|
||||
"""Custom autograd function to make CuTeDSL runner opaque to torch.compile."""
|
||||
@staticmethod
|
||||
def forward(ctx, runner, hidden_states):
|
||||
return runner._run_impl(hidden_states)
|
||||
|
||||
|
||||
class CuTeDSLSharedExpertRunner:
|
||||
"""NVFP4 shared expert runner using CuTeDSL GEMM (num_groups=1).
|
||||
|
||||
|
||||
Reference in New Issue
Block a user