rope_theta and max_position_embeddings from config (#1096)

Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: wnma3mz <wnma3mz@gmail.com>
Antoni Baum
2023-09-20 13:35:11 -07:00
committed by GitHub
parent 6f2dd6c37e
commit 3302f0aef3
9 changed files with 140 additions and 62 deletions


@@ -76,8 +76,13 @@ class QWenMLP(nn.Module):
 class QWenAttention(nn.Module):
-    def __init__(self, hidden_size: int, num_heads: int,
-                 max_position_embeddings: int):
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        max_position_embeddings: int,
+        rope_theta: float = 10000,
+    ):
         super().__init__()
         self.hidden_size = hidden_size
         tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
@@ -109,6 +114,7 @@ class QWenAttention(nn.Module):
             self.head_dim,
             self.scaling,
             rotary_dim=self.head_dim,
+            base=rope_theta,
             max_position=max_position_embeddings,
         )
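
For context, here is a minimal sketch (illustrative only, not vLLM's actual `PagedAttentionWithRoPE` implementation) of how a rotary-embedding cache is typically built from the `base` (`rope_theta`) and `max_position` (`max_position_embeddings`) values this diff now threads through from the config; the function and variable names below are assumptions:

```python
import torch

def build_rope_cache(rotary_dim: int, max_position: int, base: float = 10000.0):
    # Inverse frequencies: base ** (-2i / rotary_dim) for i in [0, rotary_dim / 2).
    inv_freq = 1.0 / (base ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
    # One rotation angle per (position, frequency) pair, cached up to max_position.
    positions = torch.arange(max_position).float()
    angles = torch.outer(positions, inv_freq)        # [max_position, rotary_dim // 2]
    return torch.cos(angles), torch.sin(angles)      # applied to q/k at attention time

# A larger rope_theta stretches the rotation wavelengths, which is how some
# checkpoints (including Qwen variants) support longer context windows.
cos_cache, sin_cache = build_rope_cache(rotary_dim=128, max_position=8192)
```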
@@ -137,8 +143,11 @@ class QWenBlock(nn.Module):
         super().__init__()
         self.ln_1 = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
-        self.attn = QWenAttention(config.n_embd, config.num_attention_heads,
-                                  config.max_position_embeddings)
+        rope_theta = getattr(config, "rope_theta", 10000)
+        self.attn = QWenAttention(config.n_embd,
+                                  config.num_attention_heads,
+                                  config.max_position_embeddings,
+                                  rope_theta=rope_theta)
         self.ln_2 = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
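
As a follow-up illustration of the `getattr(config, "rope_theta", 10000)` fallback above: models whose Hugging Face config defines `rope_theta` pick up that value, while older configs keep the previous default of 10000. The sketch below uses `SimpleNamespace` stand-ins rather than real Qwen configs (which would normally come from `AutoConfig.from_pretrained`):

```python
from types import SimpleNamespace

# Stand-ins for model configs; any object exposing these attributes behaves
# the same way as far as the getattr fallback is concerned.
cfg_with_theta = SimpleNamespace(n_embd=4096, num_attention_heads=32,
                                 max_position_embeddings=8192,
                                 rope_theta=1000000.0)
cfg_without_theta = SimpleNamespace(n_embd=4096, num_attention_heads=32,
                                    max_position_embeddings=2048)

print(getattr(cfg_with_theta, "rope_theta", 10000))     # 1000000.0, read from config
print(getattr(cfg_without_theta, "rope_theta", 10000))  # 10000, old default preserved
```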