Update rope_scaling to rope_parameters in preparation for Transformers v5 (#28542)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -5,11 +5,11 @@ from typing import NamedTuple
|
||||
import pytest
|
||||
import torch
|
||||
from packaging.version import Version
|
||||
from transformers import AutoConfig
|
||||
from transformers import __version__ as TRANSFORMERS_VERSION
|
||||
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.transformers_utils.config import get_config
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
@@ -98,8 +98,7 @@ def test_mrope(
|
||||
atol = model_info.atol
|
||||
rtol = model_info.rtol
|
||||
|
||||
config = AutoConfig.from_pretrained(model_name)
|
||||
config = config.get_text_config()
|
||||
config = get_config(model_name, False).get_text_config()
|
||||
|
||||
# get the model config
|
||||
total_num_kv_heads = config.num_key_value_heads
|
||||
@@ -113,7 +112,6 @@ def test_mrope(
|
||||
)
|
||||
is_neox_style = True
|
||||
|
||||
rope_theta = config.rope_theta
|
||||
max_position = config.max_position_embeddings
|
||||
partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
|
||||
rotary_dim = int(head_dim * partial_rotary_factor)
|
||||
@@ -122,9 +120,8 @@ def test_mrope(
|
||||
head_size=head_dim,
|
||||
rotary_dim=rotary_dim,
|
||||
max_position=max_position,
|
||||
base=rope_theta,
|
||||
is_neox_style=is_neox_style,
|
||||
rope_scaling=config.rope_scaling,
|
||||
rope_parameters=config.rope_parameters,
|
||||
dtype=dtype,
|
||||
).to(device=device)
|
||||
|
||||
@@ -173,8 +170,7 @@ def test_mrope_torch_compile_tracing(
|
||||
atol = model_info.atol
|
||||
rtol = model_info.rtol
|
||||
|
||||
config = AutoConfig.from_pretrained(model_name)
|
||||
config = config.get_text_config()
|
||||
config = get_config(model_name, False).get_text_config()
|
||||
|
||||
# get the model config
|
||||
total_num_kv_heads = config.num_key_value_heads
|
||||
@@ -187,7 +183,6 @@ def test_mrope_torch_compile_tracing(
|
||||
else config.hidden_size // total_num_heads
|
||||
)
|
||||
is_neox_style = True
|
||||
rope_theta = config.rope_theta
|
||||
max_position = config.max_position_embeddings
|
||||
partial_rotary_factor = getattr(config, "partial_rotary_factor", 1.0)
|
||||
rotary_dim = int(head_dim * partial_rotary_factor)
|
||||
@@ -196,9 +191,8 @@ def test_mrope_torch_compile_tracing(
|
||||
head_size=head_dim,
|
||||
rotary_dim=rotary_dim,
|
||||
max_position=max_position,
|
||||
base=rope_theta,
|
||||
is_neox_style=is_neox_style,
|
||||
rope_scaling=config.rope_scaling,
|
||||
rope_parameters=config.rope_parameters,
|
||||
dtype=dtype,
|
||||
).to(device=device)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user