Add activation registry (#126)
This commit is contained in:
@@ -27,6 +27,7 @@ from torch import nn
|
||||
from transformers import GPT2Config
|
||||
|
||||
from cacheflow.model_executor.input_metadata import InputMetadata
|
||||
+ from cacheflow.model_executor.layers.activation import get_act_fn
|
||||
from cacheflow.model_executor.layers.attention import GPTCacheFlowAttention
|
||||
from cacheflow.model_executor.layers.sampler import Sampler
|
||||
from cacheflow.model_executor.weight_utils import (hf_model_weights_iterator,
|
||||
@@ -92,12 +93,7 @@ class GPT2MLP(nn.Module):
|
||||
self.c_proj = RowParallelLinear(intermediate_size, hidden_size,
|
||||
bias=True, input_is_parallel=True,
|
||||
perform_initialization=False)
|
||||
|
||||
- act_fn = config.activation_function
|
||||
- if act_fn != "gelu_new":
|
||||
- raise ValueError(f"Unsupported activation: {act_fn}. "
|
||||
- "GPT-2 only supports gelu_new for now.")
|
||||
- self.act = torch.nn.GELU(approximate="tanh")
|
||||
+ self.act = get_act_fn(config.activation_function)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states, _ = self.c_fc(hidden_states)
|
||||
|
||||
@@ -26,6 +26,7 @@ from torch import nn
|
||||
from transformers import GPTNeoXConfig
|
||||
|
||||
from cacheflow.model_executor.input_metadata import InputMetadata
|
||||
+ from cacheflow.model_executor.layers.activation import get_act_fn
|
||||
from cacheflow.model_executor.layers.attention import GPTNeoXCacheFlowAttention
|
||||
from cacheflow.model_executor.layers.sampler import Sampler
|
||||
from cacheflow.model_executor.weight_utils import (hf_model_weights_iterator,
|
||||
@@ -94,10 +95,7 @@ class GPTNeoXMLP(nn.Module):
|
||||
self.dense_4h_to_h = RowParallelLinear(config.intermediate_size, config.hidden_size,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False)
|
||||
- if config.hidden_act != 'gelu':
|
||||
- raise ValueError(f'Unsupported activation: {config.hidden_act}. '
|
||||
- 'Only gelu is supported for now.')
|
||||
- self.act = torch.nn.GELU()
|
||||
+ self.act = get_act_fn(config.hidden_act)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states, _ = self.dense_h_to_4h(hidden_states)
|
||||
|
||||
@@ -26,6 +26,7 @@ from torch import nn
|
||||
from transformers import OPTConfig
|
||||
|
||||
from cacheflow.model_executor.input_metadata import InputMetadata
|
||||
+ from cacheflow.model_executor.layers.activation import get_act_fn
|
||||
from cacheflow.model_executor.layers.attention import GPTCacheFlowAttention
|
||||
from cacheflow.model_executor.layers.sampler import Sampler
|
||||
from cacheflow.model_executor.weight_utils import (hf_model_weights_iterator,
|
||||
@@ -105,8 +106,7 @@ class OPTDecoderLayer(nn.Module):
|
||||
bias=config.enable_bias,
|
||||
)
|
||||
self.do_layer_norm_before = config.do_layer_norm_before
|
||||
- assert config.activation_function == 'relu'
|
||||
- self.activation_fn = nn.ReLU()
|
||||
+ self.activation_fn = get_act_fn(config.activation_function)
|
||||
|
||||
self.self_attn_layer_norm = nn.LayerNorm(
|
||||
self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine)
|
||||
|
||||
Reference in New Issue
Block a user