rope_theta and max_position_embeddings from config (#1096)
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: wnma3mz <wnma3mz@gmail.com>
@@ -76,8 +76,13 @@ class QWenMLP(nn.Module):
 
 class QWenAttention(nn.Module):
 
-    def __init__(self, hidden_size: int, num_heads: int,
-                 max_position_embeddings: int):
+    def __init__(
+        self,
+        hidden_size: int,
+        num_heads: int,
+        max_position_embeddings: int,
+        rope_theta: float = 10000,
+    ):
         super().__init__()
         self.hidden_size = hidden_size
         tensor_model_parallel_world_size = get_tensor_model_parallel_world_size(
@@ -109,6 +114,7 @@ class QWenAttention(nn.Module):
             self.head_dim,
             self.scaling,
             rotary_dim=self.head_dim,
+            base=rope_theta,
             max_position=max_position_embeddings,
         )
 
@@ -137,8 +143,11 @@ class QWenBlock(nn.Module):
         super().__init__()
         self.ln_1 = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
 
-        self.attn = QWenAttention(config.n_embd, config.num_attention_heads,
-                                  config.max_position_embeddings)
+        rope_theta = getattr(config, "rope_theta", 10000)
+        self.attn = QWenAttention(config.n_embd,
+                                  config.num_attention_heads,
+                                  config.max_position_embeddings,
+                                  rope_theta=rope_theta)
 
         self.ln_2 = RMSNorm(config.n_embd, eps=config.layer_norm_epsilon)
 
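For context, a minimal, self-contained sketch (not vLLM's actual rotary-embedding code) of why rope_theta has to reach the attention layer: it is the base of the geometric progression of RoPE rotation frequencies, so a Qwen checkpoint trained with a non-default base would get wrong positional rotations if the kernel kept a hard-coded 10000. Only the config attribute name and the getattr fallback mirror the diff above; the helper name and the rest are illustrative.

import torch

def rotary_inv_freq(head_dim: int, rope_theta: float = 10000.0) -> torch.Tensor:
    # Standard RoPE inverse frequencies: a geometric progression whose base is
    # rope_theta. Changing the base rescales every rotation wavelength, which is
    # why the value must come from the checkpoint's config rather than a constant.
    return 1.0 / (rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))

# Hypothetical usage mirroring the diff above (config is a HuggingFace-style
# Qwen config; head_dim is the per-head dimension):
# rope_theta = getattr(config, "rope_theta", 10000)
# inv_freq = rotary_inv_freq(head_dim, rope_theta)
# angles = torch.outer(torch.arange(config.max_position_embeddings, dtype=torch.float32), inv_freq)

The 10000 default in both the diff and the sketch matches the base from the original RoPE paper, so older Qwen configs without a rope_theta field keep their previous behavior.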