[BUGFIX] Pixtral cannot be loaded with --limit-mm-per-prompt 0 (#33406)
Signed-off-by: juliendenize <julien.denize@mistral.ai>
This commit is contained in:
@@ -70,7 +70,7 @@ from .interfaces import (
|
|||||||
SupportsPP,
|
SupportsPP,
|
||||||
)
|
)
|
||||||
from .module_mapping import MultiModelKeys
|
from .module_mapping import MultiModelKeys
|
||||||
from .utils import init_vllm_registered_model, maybe_prefix
|
from .utils import StageMissingLayer, init_vllm_registered_model, maybe_prefix
|
||||||
from .vision import (
|
from .vision import (
|
||||||
VisionEncoderInfo,
|
VisionEncoderInfo,
|
||||||
VisionFeatureSelectStrategy,
|
VisionFeatureSelectStrategy,
|
||||||
@@ -93,6 +93,10 @@ except ImportError:
|
|||||||
PATCH_MERGE = "patch_merge"
|
PATCH_MERGE = "patch_merge"
|
||||||
|
|
||||||
|
|
||||||
|
def _is_layer_none_or_staged(layer: nn.Module) -> bool:
|
||||||
|
return layer is None or isinstance(layer, StageMissingLayer)
|
||||||
|
|
||||||
|
|
||||||
class PixtralImagePixelInputs(TensorSchema):
|
class PixtralImagePixelInputs(TensorSchema):
|
||||||
"""
|
"""
|
||||||
Dimensions:
|
Dimensions:
|
||||||
@@ -542,7 +546,7 @@ class PixtralForConditionalGeneration(
|
|||||||
# Single pass over weights
|
# Single pass over weights
|
||||||
for name, w in weights:
|
for name, w in weights:
|
||||||
if is_vision_encoder_weights((name, w)):
|
if is_vision_encoder_weights((name, w)):
|
||||||
if self.vision_encoder is None:
|
if _is_layer_none_or_staged(self.vision_encoder):
|
||||||
continue
|
continue
|
||||||
# Load vision encoder weights directly
|
# Load vision encoder weights directly
|
||||||
trimmed_name = ".".join(name.split(".")[1:])
|
trimmed_name = ".".join(name.split(".")[1:])
|
||||||
@@ -551,7 +555,7 @@ class PixtralForConditionalGeneration(
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
default_weight_loader(param, w)
|
default_weight_loader(param, w)
|
||||||
elif is_patch_merger((name, w)):
|
elif is_patch_merger((name, w)):
|
||||||
if self.patch_merger is None:
|
if _is_layer_none_or_staged(self.patch_merger):
|
||||||
continue
|
continue
|
||||||
# Load vision patch merger weights directly
|
# Load vision patch merger weights directly
|
||||||
trimmed_name = ".".join(name.split(".")[1:])
|
trimmed_name = ".".join(name.split(".")[1:])
|
||||||
@@ -559,7 +563,7 @@ class PixtralForConditionalGeneration(
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
default_weight_loader(param, w)
|
default_weight_loader(param, w)
|
||||||
elif is_pre_mm_projector_norm((name, w)):
|
elif is_pre_mm_projector_norm((name, w)):
|
||||||
if self.pre_mm_projector_norm is None:
|
if _is_layer_none_or_staged(self.pre_mm_projector_norm):
|
||||||
continue
|
continue
|
||||||
# Load vision pre_mm_projector_norm weights directly
|
# Load vision pre_mm_projector_norm weights directly
|
||||||
trimmed_name = ".".join(name.split(".")[1:])
|
trimmed_name = ".".join(name.split(".")[1:])
|
||||||
@@ -567,7 +571,7 @@ class PixtralForConditionalGeneration(
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
default_weight_loader(param, w)
|
default_weight_loader(param, w)
|
||||||
elif is_vision_lang_adapter_weights((name, w)):
|
elif is_vision_lang_adapter_weights((name, w)):
|
||||||
if self.vision_language_adapter is None:
|
if _is_layer_none_or_staged(self.vision_language_adapter):
|
||||||
continue
|
continue
|
||||||
# Load vision-language adapter weights directly
|
# Load vision-language adapter weights directly
|
||||||
trimmed_name = ".".join(name.split(".")[1:])
|
trimmed_name = ".".join(name.split(".")[1:])
|
||||||
|
|||||||
Reference in New Issue
Block a user