Standardise get_rope to use rope_parameters["partial_rotary_factor"], not rotary_dim (#30389)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -73,14 +73,28 @@ def get_field(cls: ConfigType, name: str) -> Field:
|
||||
)
|
||||
|
||||
|
||||
def getattr_iter(object: object, names: Iterable[str], default: Any) -> Any:
|
||||
def getattr_iter(
|
||||
object: object, names: Iterable[str], default: Any, warn: bool = False
|
||||
) -> Any:
|
||||
"""
|
||||
A helper function that retrieves an attribute from an object which may
|
||||
have multiple possible names. This is useful when fetching attributes from
|
||||
arbitrary `transformers.PretrainedConfig` instances.
|
||||
|
||||
In the case where the first name in `names` is the preferred name, and
|
||||
any other names are deprecated aliases, setting `warn=True` will log a
|
||||
warning when a deprecated name is used.
|
||||
"""
|
||||
for name in names:
|
||||
for i, name in enumerate(names):
|
||||
if hasattr(object, name):
|
||||
if warn and i > 0:
|
||||
logger.warning_once(
|
||||
"%s contains a deprecated attribute name '%s'. "
|
||||
"Please use the preferred attribute name '%s' instead.",
|
||||
type(object).__name__,
|
||||
name,
|
||||
names[0],
|
||||
)
|
||||
return getattr(object, name)
|
||||
return default
|
||||
|
||||
|
||||
Reference in New Issue
Block a user