Migrate InternLMForCausalLM to LlamaForCausalLM (#2860)
Co-authored-by: Roy <jasonailu87@gmail.com>
@@ -91,6 +91,7 @@ class LlamaAttention(nn.Module):
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
         linear_method: Optional[LinearMethodBase] = None,
+        bias: bool = False,
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
@@ -120,13 +121,13 @@ class LlamaAttention(nn.Module):
             self.head_dim,
             self.total_num_heads,
             self.total_num_kv_heads,
-            bias=False,
+            bias=bias,
             linear_method=linear_method,
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
-            bias=False,
+            bias=bias,
             linear_method=linear_method,
         )
 
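For orientation, here is a minimal sketch of what the patched attention constructor amounts to. Plain nn.Linear stands in for vLLM's tensor-parallel QKVParallelLinear and RowParallelLinear, the class name and trimmed parameter list are illustrative only, and rotary embeddings and the forward pass are omitted:

import torch.nn as nn


class BiasedAttentionSketch(nn.Module):
    # Hypothetical, simplified stand-in for LlamaAttention: plain nn.Linear
    # replaces QKVParallelLinear / RowParallelLinear, and KV-head grouping
    # beyond the output-width arithmetic is not modeled.
    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        bias: bool = False,  # the flag this commit adds
    ) -> None:
        super().__init__()
        self.head_dim = hidden_size // num_heads
        qkv_out = (num_heads + 2 * num_kv_heads) * self.head_dim
        # Previously both projections hard-coded bias=False; now the
        # constructor argument is threaded through to each of them.
        self.qkv_proj = nn.Linear(hidden_size, qkv_out, bias=bias)
        self.o_proj = nn.Linear(num_heads * self.head_dim, hidden_size, bias=bias)

The False default keeps every existing Llama call site bias-free, while InternLM checkpoints, whose attention projections ship with bias terms, can opt in through the new flag.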
@@ -179,6 +180,7 @@ class LlamaDecoderLayer(nn.Module):
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
             linear_method=linear_method,
+            bias=getattr(config, "bias", False),
         )
         self.mlp = LlamaMLP(
             hidden_size=self.hidden_size,
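In the decoder layer the flag is read off the model config with getattr, so configs without a bias field (standard Llama) fall back to False, while an InternLM-style config that sets bias to true turns the projection biases on. A quick illustration with stand-in config objects:

from types import SimpleNamespace

# Stand-in configs: a Llama-style config has no `bias` attribute, while an
# InternLM-style config sets bias=True in its config.json.
llama_cfg = SimpleNamespace(hidden_size=4096)
internlm_cfg = SimpleNamespace(hidden_size=4096, bias=True)

print(getattr(llama_cfg, "bias", False))     # False: projections stay bias-free
print(getattr(internlm_cfg, "bias", False))  # True: qkv_proj/o_proj gain biases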