Various Transformers v5 config fixes (#38247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -119,6 +119,8 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
|
||||
tarsier2="Tarsier2Config",
|
||||
)
|
||||
|
||||
_SPECULATIVE_DECODING_CONFIGS: set[str] = {"eagle", "speculators"}
|
||||
|
||||
_CONFIG_ATTRS_MAPPING: dict[str, str] = {
|
||||
"llm_config": "text_config",
|
||||
}
|
||||
@@ -190,7 +192,7 @@ class HFConfigParser(ConfigParserBase):
|
||||
dummy_model_type = hf_overrides(dummy_config).model_type
|
||||
model_type = dummy_model_type.removeprefix("dummy_")
|
||||
|
||||
if model_type in _CONFIG_REGISTRY:
|
||||
if model_type in _SPECULATIVE_DECODING_CONFIGS:
|
||||
config_class = _CONFIG_REGISTRY[model_type]
|
||||
config = config_class.from_pretrained(
|
||||
model,
|
||||
@@ -200,6 +202,14 @@ class HFConfigParser(ConfigParserBase):
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
if model_type in _CONFIG_REGISTRY:
|
||||
# Register the config class to AutoConfig to ensure it's used in future
|
||||
# calls to `from_pretrained`
|
||||
config_class = _CONFIG_REGISTRY[model_type]
|
||||
config_class.model_type = model_type
|
||||
AutoConfig.register(model_type, config_class, exist_ok=True)
|
||||
# Now that it is registered, it is not considered remote code anymore
|
||||
trust_remote_code = False
|
||||
try:
|
||||
kwargs = _maybe_update_auto_config_kwargs(kwargs, model_type=model_type)
|
||||
config = AutoConfig.from_pretrained(
|
||||
|
||||
@@ -20,7 +20,6 @@ class ColModernVBertConfig(PretrainedConfig):
|
||||
vlm_config: dict | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.embedding_dim = embedding_dim
|
||||
|
||||
if vlm_config is None:
|
||||
@@ -55,6 +54,7 @@ class ColModernVBertConfig(PretrainedConfig):
|
||||
intermediate_size=vis_cfg.get("intermediate_size", 3072),
|
||||
num_attention_heads=vis_cfg.get("num_attention_heads", 12),
|
||||
)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@property
|
||||
def image_seq_len(self) -> int:
|
||||
|
||||
@@ -87,6 +87,18 @@ class MlpProjectorConfig(PretrainedConfig):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
if hasattr(DeepseekV2Config, "validate"):
|
||||
# Transformers v5
|
||||
from huggingface_hub.dataclasses import strict
|
||||
|
||||
@strict
|
||||
class DeepseekVLV2TextConfig(DeepseekV2Config):
|
||||
kv_lora_rank: int | None = None
|
||||
else:
|
||||
# Transformers v4
|
||||
DeepseekVLV2TextConfig = DeepseekV2Config # type: ignore[misc]
|
||||
|
||||
|
||||
class DeepseekVLV2Config(PretrainedConfig):
|
||||
model_type = "deepseek_vl_v2"
|
||||
architectures: list[str] | None = None
|
||||
@@ -102,22 +114,17 @@ class DeepseekVLV2Config(PretrainedConfig):
|
||||
candidate_resolutions: tuple[tuple[int, int]] = ((384, 384),),
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
if "architectures" not in kwargs:
|
||||
kwargs["architectures"] = ["DeepseekVLV2ForCausalLM"]
|
||||
|
||||
if self.architectures is None:
|
||||
self.architectures = ["DeepseekVLV2ForCausalLM"]
|
||||
|
||||
vision_config = kwargs.get("vision_config", {})
|
||||
vision_config = kwargs.pop("vision_config", {})
|
||||
self.vision_config = VisionEncoderConfig(**vision_config)
|
||||
|
||||
projector_config = kwargs.get("projector_config", {})
|
||||
projector_config = kwargs.pop("projector_config", {})
|
||||
self.projector_config = MlpProjectorConfig(**projector_config)
|
||||
|
||||
language_config = kwargs.get("language_config", {})
|
||||
# remove kv_lora_rank if not specified, passing None is prohibited
|
||||
if language_config.get("kv_lora_rank") is None:
|
||||
language_config.pop("kv_lora_rank", None)
|
||||
self.text_config = DeepseekV2Config(**language_config)
|
||||
language_config = kwargs.pop("language_config", {})
|
||||
self.text_config = DeepseekVLV2TextConfig(**language_config)
|
||||
|
||||
self.tile_tag = tile_tag
|
||||
self.global_view_pos = global_view_pos
|
||||
@@ -125,7 +132,8 @@ class DeepseekVLV2Config(PretrainedConfig):
|
||||
self.vocab_size = self.text_config.vocab_size
|
||||
|
||||
# update model_type for OCR models
|
||||
if "DeepseekOCRForCausalLM" in self.architectures:
|
||||
if "DeepseekOCRForCausalLM" in kwargs["architectures"]:
|
||||
self.model_type = "deepseek_ocr"
|
||||
elif "DeepseekOCR2ForCausalLM" in self.architectures:
|
||||
elif "DeepseekOCR2ForCausalLM" in kwargs["architectures"]:
|
||||
self.model_type = "deepseek_ocr2"
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@@ -39,13 +39,6 @@ class FlexOlmoConfig(PretrainedConfig):
|
||||
if "architectures" not in kwargs:
|
||||
kwargs["architectures"] = ["FlexOlmoForCausalLM"]
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
@@ -80,3 +73,10 @@ class FlexOlmoConfig(PretrainedConfig):
|
||||
# BC: if there is a 'type' field, move it to 'rope_type'.
|
||||
if self.rope_parameters is not None and "type" in self.rope_parameters:
|
||||
self.rope_parameters["rope_type"] = self.rope_parameters["type"]
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -50,8 +50,6 @@ class IsaacConfig(Qwen3Config):
|
||||
vision_attn_implementation: str | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if isinstance(text_config, dict):
|
||||
# from HF config
|
||||
self.text_config = self.sub_configs["text_config"](**text_config)
|
||||
@@ -92,6 +90,7 @@ class IsaacConfig(Qwen3Config):
|
||||
vision_max_num_patches,
|
||||
)
|
||||
self.vision_attn_implementation = vision_attn_implementation
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
||||
@@ -220,7 +220,6 @@ class Qwen3NextConfig(PretrainedConfig):
|
||||
):
|
||||
if mlp_only_layers is None:
|
||||
mlp_only_layers = []
|
||||
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
@@ -279,6 +278,7 @@ class Qwen3NextConfig(PretrainedConfig):
|
||||
self.output_router_logits = output_router_logits
|
||||
self.router_aux_loss_coef = router_aux_loss_coef
|
||||
self.mlp_only_layers = mlp_only_layers
|
||||
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
||||
|
||||
|
||||
__all__ = ["Qwen3NextConfig"]
|
||||
|
||||
@@ -80,6 +80,9 @@ class Step3p5Config(PretrainedConfig):
|
||||
|
||||
self.att_impl_type = att_impl_type
|
||||
self.use_head_wise_attn_gate = use_head_wise_attn_gate
|
||||
# For some reason the checkpoint has longer layer_types than num_hidden_layers
|
||||
if layer_types is not None:
|
||||
layer_types = layer_types[: self.num_hidden_layers]
|
||||
self.layer_types = layer_types
|
||||
self.use_rope_layers = use_rope_layers
|
||||
self.yarn_only_types = yarn_only_types
|
||||
|
||||
Reference in New Issue
Block a user