From f23fb5a7c1b61350c5c40ca1115d3bf8cf2b8cc9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?RickyChen=20/=20=E9=99=B3=E6=98=AD=E5=84=92?=
Date: Wed, 21 Jan 2026 15:27:30 +0800
Subject: [PATCH] [Bugfix] Support HF sharded weights for Mistral3/Pixtral models (#32673)

Signed-off-by: ricky-chaoju
Signed-off-by: vllm-dev
---
 vllm/model_executor/models/pixtral.py | 22 ++++++++++++++--------
 1 file changed, 14 insertions(+), 8 deletions(-)

diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py
index 272af0ae7..b767bc160 100644
--- a/vllm/model_executor/models/pixtral.py
+++ b/vllm/model_executor/models/pixtral.py
@@ -502,10 +502,12 @@ class PixtralForConditionalGeneration(
     def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
         def is_vision_encoder_weights(weight: tuple[str, torch.Tensor]):
-            return weight[0].startswith("vision_encoder")
+            return weight[0].startswith(("vision_encoder", "vision_tower"))
 
         def is_vision_lang_adapter_weights(weight: tuple[str, torch.Tensor]):
-            return weight[0].startswith("vision_language_adapter")
+            return weight[0].startswith(
+                ("vision_language_adapter", "multi_modal_projector")
+            )
 
         def is_patch_merger(weight: tuple[str, torch.Tensor]):
             return weight[0].startswith("patch_merger")
 
@@ -543,9 +545,10 @@ class PixtralForConditionalGeneration(
                     continue
                 # Load vision encoder weights directly
                 trimmed_name = ".".join(name.split(".")[1:])
-                param = vision_encoder_dict[trimmed_name]
-                with torch.no_grad():
-                    default_weight_loader(param, w)
+                param = vision_encoder_dict.get(trimmed_name)
+                if param is not None:
+                    with torch.no_grad():
+                        default_weight_loader(param, w)
             elif is_patch_merger((name, w)):
                 if self.patch_merger is None:
                     continue
@@ -567,12 +570,15 @@ class PixtralForConditionalGeneration(
                     continue
                 # Load vision-language adapter weights directly
                 trimmed_name = ".".join(name.split(".")[1:])
-                param = vision_lang_adapter_dict[trimmed_name]
-                with torch.no_grad():
-                    default_weight_loader(param, w)
+                param = vision_lang_adapter_dict.get(trimmed_name)
+                if param is not None:
+                    with torch.no_grad():
+                        default_weight_loader(param, w)
             else:
                 # LLM weights: yield them to be loaded
                 # by language_model.load_weights
+                # Strip "language_model." prefix if present (HF sharded format)
+                name = name.removeprefix("language_model.")
                 yield (name, w)
 
         # Now we call the language model load with the generator
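
Editor's note: below is a minimal, standalone sketch of the name routing this patch enables, assuming that HF sharded Mistral3/Pixtral checkpoints prefix their keys with "vision_tower.", "multi_modal_projector.", and "language_model.", while the consolidated Mistral format uses "vision_encoder." and "vision_language_adapter." with unprefixed LLM keys. The helper name route_weight and the sample keys are illustrative only and are not part of vLLM or the patch.

# Simplified sketch of the prefix routing performed in load_weights above.
# Not the vLLM implementation; route_weight and the sample keys are made up
# for illustration.

def route_weight(name: str) -> tuple[str, str]:
    """Classify a checkpoint key and return (target module, key to load)."""
    # Vision encoder: "vision_encoder." (Mistral format) or
    # "vision_tower." (HF sharded format); strip the leading component.
    if name.startswith(("vision_encoder", "vision_tower")):
        return "vision_encoder", ".".join(name.split(".")[1:])
    # Vision-language adapter: "vision_language_adapter." (Mistral format)
    # or "multi_modal_projector." (HF sharded format).
    if name.startswith(("vision_language_adapter", "multi_modal_projector")):
        return "vision_lang_adapter", ".".join(name.split(".")[1:])
    # Everything else is an LLM weight; HF sharded checkpoints prefix these
    # with "language_model.", which must be stripped before handing the key
    # to the language model's own weight loader.
    return "language_model", name.removeprefix("language_model.")


if __name__ == "__main__":
    # Hypothetical keys in both naming conventions.
    for key in (
        "vision_tower.transformer.layers.0.attention.wq.weight",    # HF style
        "vision_encoder.transformer.layers.0.attention.wq.weight",  # Mistral style
        "multi_modal_projector.linear_1.weight",                    # HF style
        "language_model.model.layers.0.mlp.gate_proj.weight",       # HF style
        "layers.0.feed_forward.w1.weight",                          # Mistral style
    ):
        print(route_weight(key))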