Refactor Transformers backend to use mixins (#26906)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
@@ -147,6 +147,10 @@ class ModelConfig:
     seed: int | None = None
     """Random seed for reproducibility. Initialized to None in V0, but
     initialized to 0 in V1."""
+    hf_config: PretrainedConfig = field(init=False)
+    """The Hugging Face config of the model."""
+    hf_text_config: PretrainedConfig = field(init=False)
+    """The Hugging Face config of the text model (same as hf_config for text models)."""
     hf_config_path: str | None = None
     """Name or path of the Hugging Face config to use. If unspecified, model
     name or path will be used."""
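Note: the new `hf_config`/`hf_text_config` pair differs only for composite (e.g. multimodal) models, where the text model's config is nested inside the top-level config. A minimal sketch of that distinction, assuming the `transformers` library's `PretrainedConfig.get_text_config()` helper (the checkpoint names are purely illustrative):

from transformers import AutoConfig

# Text-only model: the top-level config already is the text config.
gpt2_cfg = AutoConfig.from_pretrained("gpt2")
print(gpt2_cfg.get_text_config() is gpt2_cfg)  # True: nothing nested

# Multimodal model: the text config is nested inside the top-level
# config, so the two objects differ.
llava_cfg = AutoConfig.from_pretrained("llava-hf/llava-1.5-7b-hf")
print(llava_cfg.get_text_config() is llava_cfg)  # False: nested text config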
@@ -771,8 +775,10 @@ class ModelConfig:
     def _get_transformers_backend_cls(self) -> str:
         """Determine which Transformers backend class will be used if
         `model_impl` is set to `transformers` or `auto`."""
-        prefix = "Transformers"
-        prefix += "MoE" if self.get_num_experts() > 1 else ""
+        cls = "Transformers"
+        # If 'hf_config != hf_text_config' it's a nested config, i.e. multimodal
+        cls += "MultiModal" if self.hf_config != self.hf_text_config else ""
+        cls += "MoE" if self.get_num_experts() > 1 else ""
         # Check if the architecture we're wrapping has defaults
         runner = None
         convert = None
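The mixin refactor builds the class name by concatenating feature tags onto a single `cls` string instead of choosing between fixed prefixes. A standalone sketch of just the naming scheme (the helper function is hypothetical, written only to show which names the concatenation produces):

def transformers_backend_name(is_multimodal: bool, num_experts: int) -> str:
    # Mirrors the composition in the hunk above (illustrative only).
    cls = "Transformers"
    cls += "MultiModal" if is_multimodal else ""
    cls += "MoE" if num_experts > 1 else ""
    return cls + "ForCausalLM"

print(transformers_backend_name(False, 0))  # TransformersForCausalLM
print(transformers_backend_name(True, 0))   # TransformersMultiModalForCausalLM
print(transformers_backend_name(True, 8))   # TransformersMultiModalMoEForCausalLM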
@@ -788,18 +794,15 @@ class ModelConfig:
             runner = "generate"
         if convert in {None, "none"}:
             convert = "embed"
-        # Resolve Transformers backend pooling classes
+        # Resolve Transformers backend task
         if runner == "pooling":
             if convert == "embed":
-                return prefix + "EmbeddingModel"
+                return cls + "EmbeddingModel"
             if convert == "classify":
-                return prefix + "ForSequenceClassification"
-        # Resolve Transformers backend generate classes
-        if self.hf_config != self.hf_text_config:
-            # If 'hf_text_config' is the same as 'hf_config'. If not, it is
-            # probably a composite config, i.e. multimodal
-            return prefix + "ForMultimodalLM"
-        return prefix + "ForCausalLM"
+                return cls + "ForSequenceClassification"
+        else:
+            cls += "ForCausalLM"
+        return cls
 
     def using_transformers_backend(self) -> bool:
         """Check if the model is using the Transformers backend class."""