Refactor Transformers backend to use mixins (#26906)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-10-16 22:50:39 +01:00
committed by GitHub
parent b2f78cbad4
commit fb5e10d3fb
17 changed files with 1510 additions and 1248 deletions

View File

@@ -147,6 +147,10 @@ class ModelConfig:
seed: int | None = None
"""Random seed for reproducibility. Initialized to None in V0, but
initialized to 0 in V1."""
hf_config: PretrainedConfig = field(init=False)
"""The Hugging Face config of the model."""
hf_text_config: PretrainedConfig = field(init=False)
"""The Hugging Face config of the text model (same as hf_config for text models)."""
hf_config_path: str | None = None
"""Name or path of the Hugging Face config to use. If unspecified, model
name or path will be used."""
@@ -771,8 +775,10 @@ class ModelConfig:
def _get_transformers_backend_cls(self) -> str:
"""Determine which Transformers backend class will be used if
`model_impl` is set to `transformers` or `auto`."""
prefix = "Transformers"
prefix += "MoE" if self.get_num_experts() > 1 else ""
cls = "Transformers"
# If 'hf_config != hf_text_config' it's a nested config, i.e. multimodal
cls += "MultiModal" if self.hf_config != self.hf_text_config else ""
cls += "MoE" if self.get_num_experts() > 1 else ""
# Check if the architecture we're wrapping has defaults
runner = None
convert = None
@@ -788,18 +794,15 @@ class ModelConfig:
runner = "generate"
if convert in {None, "none"}:
convert = "embed"
# Resolve Transformers backend pooling classes
# Resolve Transformers backend task
if runner == "pooling":
if convert == "embed":
return prefix + "EmbeddingModel"
return cls + "EmbeddingModel"
if convert == "classify":
return prefix + "ForSequenceClassification"
# Resolve Transformers backend generate classes
if self.hf_config != self.hf_text_config:
# If 'hf_text_config' is the same as 'hf_config'. If not, it is
# probably a composite config, i.e. multimodal
return prefix + "ForMultimodalLM"
return prefix + "ForCausalLM"
return cls + "ForSequenceClassification"
else:
cls += "ForCausalLM"
return cls
def using_transformers_backend(self) -> bool:
"""Check if the model is using the Transformers backend class."""