Add explicit pooling classes for the Transformers backend (#25322)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
import ast
|
||||
import inspect
|
||||
import textwrap
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import MISSING, Field, field, fields, is_dataclass, replace
|
||||
from typing import TYPE_CHECKING, Any, Protocol, TypeVar
|
||||
|
||||
@@ -52,6 +53,18 @@ def get_field(cls: ConfigType, name: str) -> Field:
|
||||
f"{cls.__name__}.{name} must have a default value or default factory.")
|
||||
|
||||
|
||||
def getattr_iter(object: object, names: Iterable[str], default: Any) -> Any:
|
||||
"""
|
||||
A helper function that retrieves an attribute from an object which may
|
||||
have multiple possible names. This is useful when fetching attributes from
|
||||
arbitrary `transformers.PretrainedConfig` instances.
|
||||
"""
|
||||
for name in names:
|
||||
if hasattr(object, name):
|
||||
return getattr(object, name)
|
||||
return default
|
||||
|
||||
|
||||
def contains_object_print(text: str) -> bool:
|
||||
"""
|
||||
Check if the text looks like a printed Python object, e.g.
|
||||
|
||||
Reference in New Issue
Block a user