[Model]: Support internlm3 (#12037)

This commit is contained in:
RunningLeon
2025-01-15 19:35:17 +08:00
committed by GitHub
parent 3adf0ffda8
commit 97eb97b5a4
4 changed files with 28 additions and 15 deletions

View File

@@ -97,20 +97,19 @@ class LlamaMLP(nn.Module):
class LlamaAttention(nn.Module):
def __init__(
self,
config: LlamaConfig,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
cache_config: Optional[CacheConfig] = None,
prefix: str = "",
) -> None:
def __init__(self,
config: LlamaConfig,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
quant_config: Optional[QuantizationConfig] = None,
bias: bool = False,
cache_config: Optional[CacheConfig] = None,
prefix: str = "",
bias_o_proj: bool = False) -> None:
super().__init__()
layer_idx = extract_layer_index(prefix)
self.hidden_size = hidden_size
@@ -150,7 +149,7 @@ class LlamaAttention(nn.Module):
self.o_proj = RowParallelLinear(
input_size=self.total_num_heads * self.head_dim,
output_size=hidden_size,
bias=bias,
bias=bias_o_proj,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
@@ -231,6 +230,11 @@ class LlamaDecoderLayer(nn.Module):
# Support internlm/internlm-7b with bias
attention_bias = getattr(config, "attention_bias", False) or getattr(
config, "bias", False)
bias_o_proj = attention_bias
# support internlm/internlm3-8b with qkv_bias
if hasattr(config, 'qkv_bias'):
attention_bias = config.qkv_bias
self.self_attn = LlamaAttention(
config=config,
hidden_size=self.hidden_size,
@@ -242,6 +246,7 @@ class LlamaDecoderLayer(nn.Module):
max_position_embeddings=max_position_embeddings,
quant_config=quant_config,
bias=attention_bias,
bias_o_proj=bias_o_proj,
cache_config=cache_config,
prefix=f"{prefix}.self_attn",
)