Add activation registry (#126)
This commit is contained in:
@@ -26,6 +26,7 @@ from torch import nn
|
||||
from transformers import GPTNeoXConfig
|
||||
|
||||
from cacheflow.model_executor.input_metadata import InputMetadata
|
||||
from cacheflow.model_executor.layers.activation import get_act_fn
|
||||
from cacheflow.model_executor.layers.attention import GPTNeoXCacheFlowAttention
|
||||
from cacheflow.model_executor.layers.sampler import Sampler
|
||||
from cacheflow.model_executor.weight_utils import (hf_model_weights_iterator,
|
||||
@@ -94,10 +95,7 @@ class GPTNeoXMLP(nn.Module):
|
||||
self.dense_4h_to_h = RowParallelLinear(config.intermediate_size, config.hidden_size,
|
||||
input_is_parallel=True,
|
||||
perform_initialization=False)
|
||||
if config.hidden_act != 'gelu':
|
||||
raise ValueError(f'Unsupported activation: {config.hidden_act}. '
|
||||
'Only gelu is supported for now.')
|
||||
self.act = torch.nn.GELU()
|
||||
self.act = get_act_fn(config.hidden_act)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states, _ = self.dense_h_to_4h(hidden_states)
|
||||
|
||||
Reference in New Issue
Block a user