[torch.compile] Fine-grained CustomOp enabling mechanism (#9300)
This commit is contained in:
@@ -11,11 +11,13 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.utils import LazyDict
|
||||
|
||||
|
||||
@CustomOp.register("fatrelu_and_mul")
|
||||
class FatreluAndMul(CustomOp):
|
||||
"""An activation function for FATReLU.
|
||||
|
||||
|
||||
The function computes x -> FATReLU(x[:d]) * x[d:] where
|
||||
d = x.shape[-1] // 2.
|
||||
This is used in openbmb/MiniCPM-S-1B-sft.
|
||||
@@ -40,6 +42,7 @@ class FatreluAndMul(CustomOp):
|
||||
return self.forward_native(x)
|
||||
|
||||
|
||||
@CustomOp.register("silu_and_mul")
|
||||
class SiluAndMul(CustomOp):
|
||||
"""An activation function for SwiGLU.
|
||||
|
||||
@@ -74,6 +77,7 @@ class SiluAndMul(CustomOp):
|
||||
return out
|
||||
|
||||
|
||||
@CustomOp.register("gelu_and_mul")
|
||||
class GeluAndMul(CustomOp):
|
||||
"""An activation function for GeGLU.
|
||||
|
||||
@@ -123,6 +127,7 @@ class GeluAndMul(CustomOp):
|
||||
return f'approximate={repr(self.approximate)}'
|
||||
|
||||
|
||||
@CustomOp.register("gelu_new")
|
||||
class NewGELU(CustomOp):
|
||||
|
||||
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
||||
@@ -144,6 +149,7 @@ class NewGELU(CustomOp):
|
||||
return ops.gelu_new(x)
|
||||
|
||||
|
||||
@CustomOp.register("gelu_fast")
|
||||
class FastGELU(CustomOp):
|
||||
|
||||
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
||||
@@ -164,8 +170,8 @@ class FastGELU(CustomOp):
|
||||
return ops.gelu_fast(x)
|
||||
|
||||
|
||||
@CustomOp.register("quick_gelu")
|
||||
class QuickGELU(CustomOp):
|
||||
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/activations.py#L90
|
||||
def forward_native(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""PyTorch-native implementation equivalent to forward()."""
|
||||
@@ -189,6 +195,7 @@ class QuickGELU(CustomOp):
|
||||
# def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
|
||||
@CustomOp.register("relu2")
|
||||
class ReLUSquaredActivation(CustomOp):
|
||||
"""
|
||||
Applies the relu^2 activation introduced in https://arxiv.org/abs/2109.08668v2
|
||||
@@ -244,15 +251,22 @@ class ScaledActivation(nn.Module):
|
||||
param_data.copy_(loaded_weight)
|
||||
|
||||
|
||||
_ACTIVATION_REGISTRY = {
|
||||
"gelu": nn.GELU(),
|
||||
"gelu_fast": FastGELU(),
|
||||
"gelu_new": NewGELU(),
|
||||
"gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
|
||||
"relu": nn.ReLU(),
|
||||
"relu2": ReLUSquaredActivation(),
|
||||
"quick_gelu": QuickGELU(),
|
||||
}
|
||||
_ACTIVATION_REGISTRY = LazyDict({
|
||||
"gelu":
|
||||
lambda: nn.GELU(),
|
||||
"gelu_fast":
|
||||
lambda: FastGELU(),
|
||||
"gelu_new":
|
||||
lambda: NewGELU(),
|
||||
"gelu_pytorch_tanh":
|
||||
lambda: nn.GELU(approximate="tanh"),
|
||||
"relu":
|
||||
lambda: nn.ReLU(),
|
||||
"relu2":
|
||||
lambda: ReLUSquaredActivation(),
|
||||
"quick_gelu":
|
||||
lambda: QuickGELU(),
|
||||
})
|
||||
|
||||
|
||||
def get_act_fn(
|
||||
|
||||
Reference in New Issue
Block a user