[Bugfix] Update multimodel models mapping to fit new checkpoint after Transformers v4.52 (#19151)

Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Isotr0py
2025-06-17 23:58:38 +08:00
committed by GitHub
parent 5a1c2e15d8
commit ca94d7fa00
12 changed files with 304 additions and 75 deletions

View File

@@ -40,8 +40,9 @@ from .clip import CLIPVisionModel
from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
from .pixtral import PixtralHFEncoderInfo, PixtralHFVisionModel
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, flatten_bn, init_vllm_registered_model,
maybe_prefix, merge_multimodal_embeddings)
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
from .vision import get_vision_encoder_info
@@ -499,6 +500,15 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
"gate_up_proj": ["gate_proj", "up_proj"]
}
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
# mapping for new names in checkpoint saved after transformers v4.52
"model.language_model.": "language_model.model.",
"model.vision_tower.": "vision_tower.",
"model.multi_modal_projector.": "multi_modal_projector.",
"lm_head.": "language_model.lm_head.",
})
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
@@ -754,7 +764,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
loader = AutoWeightsLoader(self)
return loader.load_weights(weights)
return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
class MantisProcessingInfo(LlavaProcessingInfo):