Add activation registry (#126)

Woosuk Kwon
2023-05-25 00:09:07 -07:00
committed by GitHub
parent 057daef778
commit 4a151dd453
5 changed files with 22 additions and 13 deletions


@@ -26,6 +26,7 @@ from torch import nn
 from transformers import OPTConfig
 from cacheflow.model_executor.input_metadata import InputMetadata
+from cacheflow.model_executor.layers.activation import get_act_fn
 from cacheflow.model_executor.layers.attention import GPTCacheFlowAttention
 from cacheflow.model_executor.layers.sampler import Sampler
 from cacheflow.model_executor.weight_utils import (hf_model_weights_iterator,
@@ -105,8 +106,7 @@ class OPTDecoderLayer(nn.Module):
             bias=config.enable_bias,
         )
         self.do_layer_norm_before = config.do_layer_norm_before
-        assert config.activation_function == 'relu'
-        self.activation_fn = nn.ReLU()
+        self.activation_fn = get_act_fn(config.activation_function)
         self.self_attn_layer_norm = nn.LayerNorm(
             self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine)
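
For context, a minimal sketch of what the registry added in cacheflow/model_executor/layers/activation.py could look like: get_act_fn maps the Hugging Face config's activation_function string to a torch.nn module. The registry entries and error handling below are illustrative assumptions, not the commit's exact contents.

    # Illustrative sketch only -- the entries and error handling here are assumptions.
    from torch import nn

    _ACTIVATION_REGISTRY = {
        "gelu": nn.GELU(),
        "relu": nn.ReLU(),
        "silu": nn.SiLU(),
    }

    def get_act_fn(act_fn: str) -> nn.Module:
        """Look up an activation module by its config name (e.g. 'relu', 'gelu')."""
        act_fn = act_fn.lower()
        if act_fn not in _ACTIVATION_REGISTRY:
            raise ValueError(f"Unsupported activation function: {act_fn!r}")
        return _ACTIVATION_REGISTRY[act_fn]

With such a registry in place, model code can call get_act_fn(config.activation_function) instead of hard-coding nn.ReLU(), as the OPT change above does.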