Update rope_scaling to rope_parameters in preparation for Transformers v5 (#28542)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-11-19 18:06:36 +01:00
committed by GitHub
parent d44e9df7d4
commit a8b70304d6
104 changed files with 542 additions and 910 deletions

View File

@@ -108,8 +108,7 @@ class FlashConfig(PretrainedConfig):
eos_token_id=100001,
pretraining_tp=1,
tie_word_embeddings=False,
rope_theta=1000000.0,
rope_scaling=None,
rope_parameters=None,
attention_bias=False,
attention_dropout=0.0,
mla_scale_q_lora=False,
@@ -162,8 +161,13 @@ class FlashConfig(PretrainedConfig):
self.rms_norm_eps = rms_norm_eps
self.pretraining_tp = pretraining_tp
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
# Try to set `rope_scaling` if available, otherwise use `rope_parameters`
rope_scaling = kwargs.pop("rope_scaling", None)
rope_parameters = rope_scaling or rope_parameters or {"rope_type": "default"}
rope_theta = kwargs.pop("rope_theta", 1000000.0)
if "rope_theta" not in rope_parameters:
rope_parameters["rope_theta"] = rope_theta
self.rope_parameters = rope_parameters
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
self.mla_scale_q_lora = mla_scale_q_lora
@@ -336,15 +340,7 @@ class FlashDecoderLayer(nn.Module):
super().__init__()
self.layer_idx = int(prefix.split(sep=".")[-1])
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
if rope_scaling is not None and getattr(
config, "original_max_position_embeddings", None
):
rope_scaling["original_max_position_embeddings"] = (
config.original_max_position_embeddings
)
# Dual attention structure
self.self_attn = nn.ModuleList(
@@ -361,8 +357,6 @@ class FlashDecoderLayer(nn.Module):
config.q_lora_rank if hasattr(config, "q_lora_rank") else None
),
kv_lora_rank=config.kv_lora_rank,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=None