[Misc] Move weights mapper (#11443)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li
2024-12-24 21:05:25 +08:00
committed by GitHub
parent 5c7963249d
commit 196c34b0ac
8 changed files with 74 additions and 68 deletions

View File

@@ -408,6 +408,13 @@ class Phi3VMultiModalProcessor(BaseMultiModalProcessor):
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens)
@MULTIMODAL_REGISTRY.register_processor(Phi3VMultiModalProcessor)
class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"model.vision_embed_tokens.wte": "embed_tokens",
"model.vision_embed_tokens.": "vision_embed_tokens.",
"lm_head.": "language_model.lm_head.",
"model.": "language_model.model.",
})
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
@@ -616,17 +623,10 @@ class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP):
def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
hf_to_vllm_mapper = WeightsMapper(
orig_to_new_prefix={
"model.vision_embed_tokens.wte": "embed_tokens",
"model.vision_embed_tokens.": "vision_embed_tokens.",
"lm_head.": "language_model.lm_head.",
"model.": "language_model.model.",
})
loader = AutoWeightsLoader(self)
autoloaded_weights = loader.load_weights(weights,
mapper=hf_to_vllm_mapper)
mapper=self.hf_to_vllm_mapper)
# The HF config doesn't specify whether these are tied,
# so we detect it this way