Add activation registry (#126)

This commit is contained in:
Woosuk Kwon
2023-05-25 00:09:07 -07:00
committed by GitHub
parent 057daef778
commit 4a151dd453
5 changed files with 22 additions and 13 deletions

View File

@@ -27,6 +27,7 @@ from torch import nn
from transformers import GPT2Config
from cacheflow.model_executor.input_metadata import InputMetadata
+ from cacheflow.model_executor.layers.activation import get_act_fn
from cacheflow.model_executor.layers.attention import GPTCacheFlowAttention
from cacheflow.model_executor.layers.sampler import Sampler
from cacheflow.model_executor.weight_utils import (hf_model_weights_iterator,
@@ -92,12 +93,7 @@ class GPT2MLP(nn.Module):
self.c_proj = RowParallelLinear(intermediate_size, hidden_size,
bias=True, input_is_parallel=True,
perform_initialization=False)
- act_fn = config.activation_function
- if act_fn != "gelu_new":
- raise ValueError(f"Unsupported activation: {act_fn}. "
- "GPT-2 only supports gelu_new for now.")
- self.act = torch.nn.GELU(approximate="tanh")
+ self.act = get_act_fn(config.activation_function)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.c_fc(hidden_states)