Update rope_scaling to rope_parameters in preparation for Transformers v5 (#28542)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -86,7 +86,7 @@ class PhiMoEConfig(PretrainedConfig):
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
tie_word_embeddings=False,
|
||||
rope_theta=1e6,
|
||||
rope_parameters=None,
|
||||
sliding_window=None,
|
||||
attention_dropout=0.0,
|
||||
num_experts_per_tok=2,
|
||||
@@ -119,7 +119,9 @@ class PhiMoEConfig(PretrainedConfig):
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
if rope_parameters is None:
|
||||
rope_theta = kwargs.pop("rope_theta", 1e6)
|
||||
rope_parameters = {"rope_type": "default", "rope_theta": rope_theta}
|
||||
self.attention_dropout = attention_dropout
|
||||
|
||||
self.num_experts_per_tok = num_experts_per_tok
|
||||
@@ -302,12 +304,11 @@ class PhiMoEAttention(nn.Module):
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
rope_parameters: dict,
|
||||
head_dim: int | None = None,
|
||||
max_position: int = 4096 * 32,
|
||||
rope_theta: float = 10000,
|
||||
cache_config: CacheConfig | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
rope_scaling: dict | None = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
@@ -332,8 +333,6 @@ class PhiMoEAttention(nn.Module):
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size,
|
||||
@@ -355,9 +354,8 @@ class PhiMoEAttention(nn.Module):
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=max_position,
|
||||
base=int(self.rope_theta),
|
||||
rope_parameters=rope_parameters,
|
||||
is_neox_style=True,
|
||||
rope_scaling=self.rope_scaling,
|
||||
)
|
||||
self.attn = Attention(
|
||||
self.num_heads,
|
||||
@@ -393,7 +391,6 @@ class PhiMoEDecoderLayer(nn.Module):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
# Requires transformers > 4.32.0
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
self.self_attn = PhiMoEAttention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
@@ -402,10 +399,9 @@ class PhiMoEDecoderLayer(nn.Module):
|
||||
head_dim=getattr(
|
||||
config, "head_dim", self.hidden_size // config.num_attention_heads
|
||||
),
|
||||
rope_theta=rope_theta,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
rope_scaling=config.rope_scaling,
|
||||
rope_parameters=config.rope_parameters,
|
||||
prefix=f"{prefix}.self_attn",
|
||||
)
|
||||
self.block_sparse_moe = PhiMoE(
|
||||
|
||||
Reference in New Issue
Block a user