[Bugfix] Support HF sharded weights for Mistral3/Pixtral models (#32673)
Signed-off-by: ricky-chaoju <ricky.chen@infinirc.com> Signed-off-by: vllm-dev <ricky.chen@infinirc.com>
This commit is contained in:
@@ -502,10 +502,12 @@ class PixtralForConditionalGeneration(
|
|||||||
|
|
||||||
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
|
||||||
def is_vision_encoder_weights(weight: tuple[str, torch.Tensor]):
|
def is_vision_encoder_weights(weight: tuple[str, torch.Tensor]):
|
||||||
return weight[0].startswith("vision_encoder")
|
return weight[0].startswith(("vision_encoder", "vision_tower"))
|
||||||
|
|
||||||
def is_vision_lang_adapter_weights(weight: tuple[str, torch.Tensor]):
|
def is_vision_lang_adapter_weights(weight: tuple[str, torch.Tensor]):
|
||||||
return weight[0].startswith("vision_language_adapter")
|
return weight[0].startswith(
|
||||||
|
("vision_language_adapter", "multi_modal_projector")
|
||||||
|
)
|
||||||
|
|
||||||
def is_patch_merger(weight: tuple[str, torch.Tensor]):
|
def is_patch_merger(weight: tuple[str, torch.Tensor]):
|
||||||
return weight[0].startswith("patch_merger")
|
return weight[0].startswith("patch_merger")
|
||||||
@@ -543,9 +545,10 @@ class PixtralForConditionalGeneration(
|
|||||||
continue
|
continue
|
||||||
# Load vision encoder weights directly
|
# Load vision encoder weights directly
|
||||||
trimmed_name = ".".join(name.split(".")[1:])
|
trimmed_name = ".".join(name.split(".")[1:])
|
||||||
param = vision_encoder_dict[trimmed_name]
|
param = vision_encoder_dict.get(trimmed_name)
|
||||||
with torch.no_grad():
|
if param is not None:
|
||||||
default_weight_loader(param, w)
|
with torch.no_grad():
|
||||||
|
default_weight_loader(param, w)
|
||||||
elif is_patch_merger((name, w)):
|
elif is_patch_merger((name, w)):
|
||||||
if self.patch_merger is None:
|
if self.patch_merger is None:
|
||||||
continue
|
continue
|
||||||
@@ -567,12 +570,15 @@ class PixtralForConditionalGeneration(
|
|||||||
continue
|
continue
|
||||||
# Load vision-language adapter weights directly
|
# Load vision-language adapter weights directly
|
||||||
trimmed_name = ".".join(name.split(".")[1:])
|
trimmed_name = ".".join(name.split(".")[1:])
|
||||||
param = vision_lang_adapter_dict[trimmed_name]
|
param = vision_lang_adapter_dict.get(trimmed_name)
|
||||||
with torch.no_grad():
|
if param is not None:
|
||||||
default_weight_loader(param, w)
|
with torch.no_grad():
|
||||||
|
default_weight_loader(param, w)
|
||||||
else:
|
else:
|
||||||
# LLM weights: yield them to be loaded
|
# LLM weights: yield them to be loaded
|
||||||
# by language_model.load_weights
|
# by language_model.load_weights
|
||||||
|
# Strip "language_model." prefix if present (HF sharded format)
|
||||||
|
name = name.removeprefix("language_model.")
|
||||||
yield (name, w)
|
yield (name, w)
|
||||||
|
|
||||||
# Now we call the language model load with the generator
|
# Now we call the language model load with the generator
|
||||||
|
|||||||
Reference in New Issue
Block a user