Various Transformers v5 config fixes (#38247)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2026-03-26 23:06:59 +00:00
committed by GitHub
parent 28048bd6b0
commit f73bcb1c51
7 changed files with 45 additions and 25 deletions

View File

@@ -119,6 +119,8 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
tarsier2="Tarsier2Config",
)
# Model types that get special handling in `HFConfigParser`: their config class
# is taken from `_CONFIG_REGISTRY` and loaded with its own `from_pretrained`.
# NOTE(review): other registered model types appear to go through `AutoConfig`
# instead — confirm against the full parser method.
_SPECULATIVE_DECODING_CONFIGS: set[str] = {"eagle", "speculators"}
# Attribute-name mapping from "llm_config" to "text_config".
# NOTE(review): presumably used to normalise nested config attribute names;
# the consumer is not visible here — confirm.
_CONFIG_ATTRS_MAPPING: dict[str, str] = {
"llm_config": "text_config",
}
@@ -190,7 +192,7 @@ class HFConfigParser(ConfigParserBase):
dummy_model_type = hf_overrides(dummy_config).model_type
model_type = dummy_model_type.removeprefix("dummy_")
if model_type in _CONFIG_REGISTRY:
if model_type in _SPECULATIVE_DECODING_CONFIGS:
config_class = _CONFIG_REGISTRY[model_type]
config = config_class.from_pretrained(
model,
@@ -200,6 +202,14 @@ class HFConfigParser(ConfigParserBase):
**kwargs,
)
else:
if model_type in _CONFIG_REGISTRY:
# Register the config class to AutoConfig to ensure it's used in future
# calls to `from_pretrained`
config_class = _CONFIG_REGISTRY[model_type]
config_class.model_type = model_type
AutoConfig.register(model_type, config_class, exist_ok=True)
# Now that it is registered, it is not considered remote code anymore
trust_remote_code = False
try:
kwargs = _maybe_update_auto_config_kwargs(kwargs, model_type=model_type)
config = AutoConfig.from_pretrained(

View File

@@ -20,7 +20,6 @@ class ColModernVBertConfig(PretrainedConfig):
vlm_config: dict | None = None,
**kwargs,
):
super().__init__(**kwargs)
self.embedding_dim = embedding_dim
if vlm_config is None:
@@ -55,6 +54,7 @@ class ColModernVBertConfig(PretrainedConfig):
intermediate_size=vis_cfg.get("intermediate_size", 3072),
num_attention_heads=vis_cfg.get("num_attention_heads", 12),
)
super().__init__(**kwargs)
@property
def image_seq_len(self) -> int:

View File

@@ -87,6 +87,18 @@ class MlpProjectorConfig(PretrainedConfig):
super().__init__(**kwargs)
# Choose the text-config class used for the language model of DeepSeek-VL2.
# The presence of `DeepseekV2Config.validate` is used as the marker for
# Transformers v5; otherwise the v4 behaviour is assumed.
if hasattr(DeepseekV2Config, "validate"):
    # Transformers v5
    # This import is only executed on the v5 branch.
    from huggingface_hub.dataclasses import strict

    # Re-declare `kv_lora_rank` as optional with a default of None.
    # NOTE(review): presumably the base config's strict validation rejects
    # `kv_lora_rank=None`, which some checkpoints carry — confirm.
    @strict
    class DeepseekVLV2TextConfig(DeepseekV2Config):
        kv_lora_rank: int | None = None

else:
    # Transformers v4
    # No subclass needed: alias the upstream config class directly.
    DeepseekVLV2TextConfig = DeepseekV2Config  # type: ignore[misc]
class DeepseekVLV2Config(PretrainedConfig):
model_type = "deepseek_vl_v2"
architectures: list[str] | None = None
@@ -102,22 +114,17 @@ class DeepseekVLV2Config(PretrainedConfig):
candidate_resolutions: tuple[tuple[int, int]] = ((384, 384),),
**kwargs,
):
super().__init__(**kwargs)
if "architectures" not in kwargs:
kwargs["architectures"] = ["DeepseekVLV2ForCausalLM"]
if self.architectures is None:
self.architectures = ["DeepseekVLV2ForCausalLM"]
vision_config = kwargs.get("vision_config", {})
vision_config = kwargs.pop("vision_config", {})
self.vision_config = VisionEncoderConfig(**vision_config)
projector_config = kwargs.get("projector_config", {})
projector_config = kwargs.pop("projector_config", {})
self.projector_config = MlpProjectorConfig(**projector_config)
language_config = kwargs.get("language_config", {})
# remove kv_lora_rank if not specified, passing None is prohibited
if language_config.get("kv_lora_rank") is None:
language_config.pop("kv_lora_rank", None)
self.text_config = DeepseekV2Config(**language_config)
language_config = kwargs.pop("language_config", {})
self.text_config = DeepseekVLV2TextConfig(**language_config)
self.tile_tag = tile_tag
self.global_view_pos = global_view_pos
@@ -125,7 +132,8 @@ class DeepseekVLV2Config(PretrainedConfig):
self.vocab_size = self.text_config.vocab_size
# update model_type for OCR models
if "DeepseekOCRForCausalLM" in self.architectures:
if "DeepseekOCRForCausalLM" in kwargs["architectures"]:
self.model_type = "deepseek_ocr"
elif "DeepseekOCR2ForCausalLM" in self.architectures:
elif "DeepseekOCR2ForCausalLM" in kwargs["architectures"]:
self.model_type = "deepseek_ocr2"
super().__init__(**kwargs)

View File

@@ -39,13 +39,6 @@ class FlexOlmoConfig(PretrainedConfig):
if "architectures" not in kwargs:
kwargs["architectures"] = ["FlexOlmoForCausalLM"]
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
@@ -80,3 +73,10 @@ class FlexOlmoConfig(PretrainedConfig):
# BC: if there is a 'type' field, move it to 'rope_type'.
if self.rope_parameters is not None and "type" in self.rope_parameters:
self.rope_parameters["rope_type"] = self.rope_parameters["type"]
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)

View File

@@ -50,8 +50,6 @@ class IsaacConfig(Qwen3Config):
vision_attn_implementation: str | None = None,
**kwargs,
):
super().__init__(**kwargs)
if isinstance(text_config, dict):
# from HF config
self.text_config = self.sub_configs["text_config"](**text_config)
@@ -92,6 +90,7 @@ class IsaacConfig(Qwen3Config):
vision_max_num_patches,
)
self.vision_attn_implementation = vision_attn_implementation
super().__init__(**kwargs)
__all__ = [

View File

@@ -220,7 +220,6 @@ class Qwen3NextConfig(PretrainedConfig):
):
if mlp_only_layers is None:
mlp_only_layers = []
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
@@ -279,6 +278,7 @@ class Qwen3NextConfig(PretrainedConfig):
self.output_router_logits = output_router_logits
self.router_aux_loss_coef = router_aux_loss_coef
self.mlp_only_layers = mlp_only_layers
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
__all__ = ["Qwen3NextConfig"]

View File

@@ -80,6 +80,9 @@ class Step3p5Config(PretrainedConfig):
self.att_impl_type = att_impl_type
self.use_head_wise_attn_gate = use_head_wise_attn_gate
# For some reason the checkpoint has longer layer_types than num_hidden_layers
if layer_types is not None:
layer_types = layer_types[: self.num_hidden_layers]
self.layer_types = layer_types
self.use_rope_layers = use_rope_layers
self.yarn_only_types = yarn_only_types