feat(moe): Add is_act_and_mul=False support for Triton MoE kernels (#31645)
Signed-off-by: rabi <ramishra@redhat.com>
@@ -4,7 +4,7 @@ from abc import ABC, abstractmethod
 from collections.abc import Callable
 from dataclasses import dataclass
 from enum import Enum
-from math import prod
+from math import prod, sqrt
 from typing import final
 
 import torch
@@ -575,14 +575,35 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
     def activation(
         self, activation: str, output: torch.Tensor, input: torch.Tensor
     ) -> None:
-        assert output.size(-1) * 2 == input.size(-1)
+        # Fused activations (SwiGLU-style): output is half the size of input
         if activation == "silu":
+            assert output.size(-1) * 2 == input.size(-1)
             torch.ops._C.silu_and_mul(output, input)
         elif activation == "gelu":
+            assert output.size(-1) * 2 == input.size(-1)
             torch.ops._C.gelu_and_mul(output, input)
         elif activation == "swigluoai":
             # alpha = 1.702, limit = 7.0
+            assert output.size(-1) * 2 == input.size(-1)
             torch.ops._C.swigluoai_and_mul(output, input)
+        # Non-fused activations (is_act_and_mul=False): output same size as input
+        elif activation == "silu_no_mul":
+            assert output.size(-1) == input.size(-1)
+            # Use out= argument to avoid intermediate tensor
+            torch.sigmoid(input, out=output)
+            output.mul_(input)
+        elif activation == "gelu_no_mul":
+            assert output.size(-1) == input.size(-1)
+            # GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
+            # Use out= and in-place ops to avoid intermediate tensors
+            output.copy_(input).div_(sqrt(2))
+            torch.erf(output, out=output)
+            output.add_(1).mul_(input).mul_(0.5)
+        elif activation == "relu2_no_mul":
+            assert output.size(-1) == input.size(-1)
+            # ReLU²: clamp has out=, then in-place square
+            torch.clamp(input, min=0, out=output)
+            output.square_()
         else:
             raise ValueError(f"Unsupported FusedMoe activation: {activation}")
 
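For context, the fused *_and_mul branches consume a concatenated gate/up input (so output is half the width of input), while the new *_no_mul branches apply the activation element-wise with matching shapes. The following is a minimal standalone sketch, not part of this change (the tensor names x and out are illustrative), checking that the in-place formulations used above agree with the reference PyTorch activations:

from math import sqrt

import torch

x = torch.randn(4, 128)
out = torch.empty_like(x)

# silu_no_mul: sigmoid(x) * x should match F.silu(x)
torch.sigmoid(x, out=out)
out.mul_(x)
assert torch.allclose(out, torch.nn.functional.silu(x), atol=1e-6)

# gelu_no_mul: 0.5 * x * (1 + erf(x / sqrt(2))) should match exact (erf-based) F.gelu(x)
out.copy_(x).div_(sqrt(2))
torch.erf(out, out=out)
out.add_(1).mul_(x).mul_(0.5)
assert torch.allclose(out, torch.nn.functional.gelu(x), atol=1e-6)

# relu2_no_mul: relu(x) squared
torch.clamp(x, min=0, out=out)
out.square_()
assert torch.allclose(out, torch.relu(x).square(), atol=1e-6)

As the diff's comments note, writing through out= and then mutating in place avoids allocating intermediate tensors in these new branches.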