Update rope_scaling to rope_parameters in preparation for Transformers v5 (#28542)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -77,6 +77,7 @@ from vllm.model_executor.models.utils import (
|
||||
sequence_parallel_chunk,
|
||||
)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.transformers_utils.config import set_default_rope_theta
|
||||
|
||||
|
||||
def check_ffn_act_fn(act_fn: str):
|
||||
@@ -259,7 +260,6 @@ class OpenPanguMLAAttention(nn.Module):
|
||||
v_head_dim: int,
|
||||
q_lora_rank: int | None,
|
||||
kv_lora_rank: int,
|
||||
rope_theta: float = 10000,
|
||||
max_position_embeddings: int = 8192,
|
||||
cache_config: CacheConfig | None = None,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
@@ -274,8 +274,6 @@ class OpenPanguMLAAttention(nn.Module):
|
||||
self.v_head_dim = v_head_dim
|
||||
self.q_lora_rank = q_lora_rank
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.rope_theta = rope_theta
|
||||
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
if num_heads % self.tp_size != 0:
|
||||
raise ValueError(
|
||||
@@ -339,7 +337,9 @@ class OpenPanguMLAAttention(nn.Module):
|
||||
)
|
||||
|
||||
# TODO: remove hard coding
|
||||
rope_scaling = {
|
||||
set_default_rope_theta(config, default_theta=10000)
|
||||
rope_parameters = {
|
||||
"rope_theta": config.rope_parameters["rope_theta"],
|
||||
"beta_fast": 32,
|
||||
"beta_slow": 1,
|
||||
"factor": 1,
|
||||
@@ -353,8 +353,7 @@ class OpenPanguMLAAttention(nn.Module):
|
||||
qk_rope_head_dim,
|
||||
rotary_dim=qk_rope_head_dim,
|
||||
max_position=max_position_embeddings,
|
||||
base=rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
rope_parameters=rope_parameters,
|
||||
is_neox_style=False,
|
||||
)
|
||||
|
||||
@@ -407,8 +406,6 @@ class OpenPanguEmbeddedAttention(nn.Module):
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
rope_theta: float = 10000,
|
||||
rope_scaling: dict[str, Any] | None = None,
|
||||
max_position_embeddings: int = 8192,
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
bias: bool = False,
|
||||
@@ -454,7 +451,6 @@ class OpenPanguEmbeddedAttention(nn.Module):
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
@@ -475,9 +471,7 @@ class OpenPanguEmbeddedAttention(nn.Module):
|
||||
prefix=f"{prefix}.o_proj",
|
||||
)
|
||||
|
||||
self._init_rotary_emb(
|
||||
config, rope_scaling=rope_scaling, quant_config=quant_config
|
||||
)
|
||||
self._init_rotary_emb(config, quant_config=quant_config)
|
||||
|
||||
if hasattr(config, "interleaved_sliding_window"):
|
||||
interleaved_sliding_window = config.interleaved_sliding_window
|
||||
@@ -521,7 +515,6 @@ class OpenPanguEmbeddedAttention(nn.Module):
|
||||
def _init_rotary_emb(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
rope_scaling: dict[str, Any] | None,
|
||||
quant_config: QuantizationConfig | None,
|
||||
) -> None:
|
||||
is_neox_style = True
|
||||
@@ -533,8 +526,7 @@ class OpenPanguEmbeddedAttention(nn.Module):
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=self.max_position_embeddings,
|
||||
base=self.rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
rope_parameters=config.rope_parameters,
|
||||
is_neox_style=is_neox_style,
|
||||
)
|
||||
|
||||
@@ -555,7 +547,6 @@ class OpenPanguDecoderLayer(nn.Module):
|
||||
parallel_config = vllm_config.parallel_config
|
||||
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
|
||||
|
||||
layer_idx = int(prefix.split(sep=".")[-1])
|
||||
@@ -579,7 +570,6 @@ class OpenPanguDecoderLayer(nn.Module):
|
||||
config.q_lora_rank if hasattr(config, "q_lora_rank") else None
|
||||
),
|
||||
kv_lora_rank=config.kv_lora_rank,
|
||||
rope_theta=rope_theta,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
@@ -607,8 +597,6 @@ class OpenPanguDecoderLayer(nn.Module):
|
||||
num_kv_heads=getattr(
|
||||
config, "num_key_value_heads", config.num_attention_heads
|
||||
),
|
||||
rope_theta=rope_theta,
|
||||
rope_scaling=getattr(config, "rope_scaling", None),
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
quant_config=quant_config,
|
||||
bias=attention_bias,
|
||||
|
||||
Reference in New Issue
Block a user