feat(moe): Add is_act_and_mul=False support for Triton MoE kernels (#31645)

Signed-off-by: rabi <ramishra@redhat.com>
This commit is contained in:
Rabi Mishra
2026-01-08 07:57:09 +05:30
committed by GitHub
parent 0d7667419f
commit 25eef3dc2e
7 changed files with 191 additions and 9 deletions

View File

@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
from collections.abc import Callable
from dataclasses import dataclass
from enum import Enum
from math import prod
from math import prod, sqrt
from typing import final
import torch
@@ -575,14 +575,35 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
def activation(
    self, activation: str, output: torch.Tensor, input: torch.Tensor
) -> None:
    """Apply the named MoE activation, writing the result into `output`.

    Two families are supported:
      * Fused "act-and-mul" (SwiGLU-style) activations, where `input` holds
        concatenated gate/up projections and `output` is half its last dim.
      * Plain activations (`is_act_and_mul=False`), where `output` has the
        same last dim as `input`.

    Args:
        activation: One of "silu", "gelu", "swigluoai" (fused) or
            "silu_no_mul", "gelu_no_mul", "relu2_no_mul" (non-fused).
        output: Destination tensor; written in place.
        input: Source activations; not modified.

    Raises:
        ValueError: If `activation` is not one of the supported names.
    """
    # Fused activations (SwiGLU-style): output is half the size of input.
    if activation == "silu":
        assert output.size(-1) * 2 == input.size(-1)
        torch.ops._C.silu_and_mul(output, input)
    elif activation == "gelu":
        assert output.size(-1) * 2 == input.size(-1)
        torch.ops._C.gelu_and_mul(output, input)
    elif activation == "swigluoai":
        # Custom op hard-codes alpha = 1.702, limit = 7.0.
        assert output.size(-1) * 2 == input.size(-1)
        torch.ops._C.swigluoai_and_mul(output, input)
    # Non-fused activations (is_act_and_mul=False): output same size as input.
    elif activation == "silu_no_mul":
        assert output.size(-1) == input.size(-1)
        # SiLU(x) = x * sigmoid(x); use out= to avoid an intermediate tensor.
        torch.sigmoid(input, out=output)
        output.mul_(input)
    elif activation == "gelu_no_mul":
        assert output.size(-1) == input.size(-1)
        # GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
        # Use out= and in-place ops to avoid intermediate tensors.
        output.copy_(input).div_(sqrt(2))
        torch.erf(output, out=output)
        output.add_(1).mul_(input).mul_(0.5)
    elif activation == "relu2_no_mul":
        assert output.size(-1) == input.size(-1)
        # ReLU^2: clamp supports out=, then square in place.
        torch.clamp(input, min=0, out=output)
        output.square_()
    else:
        raise ValueError(f"Unsupported FusedMoe activation: {activation}")