Revert "feat(moe): Add is_act_and_mul=False support for Triton MoE kernels" (#31978)

Michael Goin
2026-01-08 14:31:53 -05:00
committed by GitHub
parent 7508243249
commit 87e07a6b46
7 changed files with 9 additions and 191 deletions


@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
 from collections.abc import Callable
 from dataclasses import dataclass
 from enum import Enum
-from math import prod, sqrt
+from math import prod
 from typing import final
 import torch
@@ -575,35 +575,14 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
     def activation(
         self, activation: str, output: torch.Tensor, input: torch.Tensor
     ) -> None:
-        # Fused activations (SwiGLU-style): output is half the size of input
+        assert output.size(-1) * 2 == input.size(-1)
         if activation == "silu":
-            assert output.size(-1) * 2 == input.size(-1)
             torch.ops._C.silu_and_mul(output, input)
         elif activation == "gelu":
-            assert output.size(-1) * 2 == input.size(-1)
             torch.ops._C.gelu_and_mul(output, input)
         elif activation == "swigluoai":
             # alpha = 1.702, limit = 7.0
-            assert output.size(-1) * 2 == input.size(-1)
             torch.ops._C.swigluoai_and_mul(output, input)
-        # Non-fused activations (is_act_and_mul=False): output same size as input
-        elif activation == "silu_no_mul":
-            assert output.size(-1) == input.size(-1)
-            # Use out= argument to avoid an intermediate tensor
-            torch.sigmoid(input, out=output)
-            output.mul_(input)
-        elif activation == "gelu_no_mul":
-            assert output.size(-1) == input.size(-1)
-            # GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
-            # Use out= and in-place ops to avoid intermediate tensors
-            output.copy_(input).div_(sqrt(2))
-            torch.erf(output, out=output)
-            output.add_(1).mul_(input).mul_(0.5)
-        elif activation == "relu2_no_mul":
-            assert output.size(-1) == input.size(-1)
-            # ReLU²: clamp has out=, then in-place square
-            torch.clamp(input, min=0, out=output)
-            output.square_()
         else:
             raise ValueError(f"Unsupported FusedMoe activation: {activation}")