rope_theta and max_position_embeddings from config (#1096)

Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: wnma3mz <wnma3mz@gmail.com>
This commit is contained in:
Antoni Baum
2023-09-20 13:35:11 -07:00
committed by GitHub
parent 6f2dd6c37e
commit 3302f0aef3
9 changed files with 140 additions and 62 deletions

View File

@@ -111,6 +111,8 @@ class BaiChuanAttention(nn.Module):
hidden_size: int,
num_heads: int,
position_embedding: str,
rope_theta: float = 10000,
max_position_embeddings: int = 8192,
):
super().__init__()
self.hidden_size = hidden_size
@@ -122,6 +124,8 @@ class BaiChuanAttention(nn.Module):
tensor_model_parallel_world_size)
self.head_dim = hidden_size // self.total_num_heads
self.postion_embedding = position_embedding
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
# pylint: disable=invalid-name
self.W_pack = ColumnParallelLinear(
@@ -151,10 +155,13 @@ class BaiChuanAttention(nn.Module):
scaling, alibi_slopes)
else:
self.scaling = self.head_dim**-0.5
self.attn = PagedAttentionWithRoPE(self.num_heads,
self.head_dim,
self.scaling,
rotary_dim=self.head_dim)
self.attn = PagedAttentionWithRoPE(
self.num_heads,
self.head_dim,
self.scaling,
rotary_dim=self.head_dim,
base=self.rope_theta,
max_position=self.max_position_embeddings)
def forward(
self,
@@ -183,10 +190,15 @@ class BaiChuanDecoderLayer(nn.Module):
def __init__(self, config: BaiChuanConfig, position_embedding: str):
super().__init__()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
self.self_attn = BaiChuanAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
position_embedding=position_embedding,
rope_theta=rope_theta,
max_position_embeddings=max_position_embeddings,
)
self.mlp = BaiChuanMLP(
hidden_size=self.hidden_size,