[torch.compile] Fine-grained CustomOp enabling mechanism (#9300)

This commit is contained in:
Luka Govedič
2024-10-17 14:36:37 -04:00
committed by GitHub
parent 7871659abb
commit 0f41fbe5a3
8 changed files with 220 additions and 21 deletions

View File

@@ -11,11 +11,13 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import LazyDict
@CustomOp.register("fatrelu_and_mul")
class FatreluAndMul(CustomOp):
"""An activation function for FATReLU.
The function computes x -> FATReLU(x[:d]) * x[d:] where
d = x.shape[-1] // 2.
This is used in openbmb/MiniCPM-S-1B-sft.
@@ -40,6 +42,7 @@ class FatreluAndMul(CustomOp):
return self.forward_native(x)
@CustomOp.register("silu_and_mul")
class SiluAndMul(CustomOp):
"""An activation function for SwiGLU.
@@ -74,6 +77,7 @@ class SiluAndMul(CustomOp):
return out
@CustomOp.register("gelu_and_mul")
class GeluAndMul(CustomOp):
"""An activation function for GeGLU.
@@ -123,6 +127,7 @@ class GeluAndMul(CustomOp):
return f'approximate={repr(self.approximate)}'
@CustomOp.register("gelu_new")
class NewGELU(CustomOp):
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
@@ -144,6 +149,7 @@ class NewGELU(CustomOp):
return ops.gelu_new(x)
@CustomOp.register("gelu_fast")
class FastGELU(CustomOp):
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
@@ -164,8 +170,8 @@ class FastGELU(CustomOp):
return ops.gelu_fast(x)
@CustomOp.register("quick_gelu")
class QuickGELU(CustomOp):
# https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
@@ -189,6 +195,7 @@ class QuickGELU(CustomOp):
# def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
@CustomOp.register("relu2")
class ReLUSquaredActivation(CustomOp):
"""
Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
@@ -244,15 +251,22 @@ class ScaledActivation(nn.Module):
param_data.copy_(loaded_weight)
_ACTIVATION_REGISTRY = {
"gelu": nn.GELU(),
"gelu_fast": FastGELU(),
"gelu_new": NewGELU(),
"gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
"relu": nn.ReLU(),
"relu2": ReLUSquaredActivation(),
"quick_gelu": QuickGELU(),
}
_ACTIVATION_REGISTRY = LazyDict({
"gelu":
lambda: nn.GELU(),
"gelu_fast":
lambda: FastGELU(),
"gelu_new":
lambda: NewGELU(),
"gelu_pytorch_tanh":
lambda: nn.GELU(approximate="tanh"),
"relu":
lambda: nn.ReLU(),
"relu2":
lambda: ReLUSquaredActivation(),
"quick_gelu":
lambda: QuickGELU(),
})
def get_act_fn(