[Misc] Allow LM only loading for Pixtral (#29451)
Signed-off-by: Roger Wang <hey@rogerw.io>
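This change lets Pixtral skip constructing (and loading weights for) its entire vision stack when multimodal input is disabled, so the checkpoint can be served as a text-only language model. A minimal usage sketch, assuming the standard vLLM `LLM` entrypoint and the existing `limit_mm_per_prompt` engine argument (the checkpoint name is illustrative):

```python
from vllm import LLM

# Setting the per-prompt image limit to 0 signals that no images will ever
# be passed. With this change, the vision encoder, patch merger,
# pre-projector norm, and vision-language adapter are then left as None
# and their checkpoint tensors are skipped during weight loading.
llm = LLM(
    model="mistralai/Pixtral-12B-2409",
    limit_mm_per_prompt={"image": 0},
)
```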
@@ -400,21 +400,30 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
             prefix=maybe_prefix(prefix, "language_model"),
         )
 
-        self.vision_encoder = VisionTransformer(self.vision_args)
-        if self.vision_args.add_pre_mm_projector_layer_norm:
-            self.pre_mm_projector_norm = RMSNorm(self.vision_args.hidden_size, eps=1e-5)
-        if self.vision_args.mm_projector_id == PATCH_MERGE:
-            self.patch_merger = PatchMerger(
-                vision_encoder_dim=self.vision_args.hidden_size,
-                spatial_merge_size=self.vision_args.spatial_merge_size,
-                use_mlp_bias=False,
-            )
-        self.vision_language_adapter = VisionLanguageAdapter(
-            self.vision_args, dim=config.text_config.hidden_size
-        )
+        if multimodal_config.get_limit_per_prompt("image"):
+            self.vision_encoder = VisionTransformer(self.vision_args)
+            self.pre_mm_projector_norm = (
+                RMSNorm(self.vision_args.hidden_size, eps=1e-5)
+                if self.vision_args.add_pre_mm_projector_layer_norm
+                else None
+            )
+            self.patch_merger = (
+                PatchMerger(
+                    vision_encoder_dim=self.vision_args.hidden_size,
+                    spatial_merge_size=self.vision_args.spatial_merge_size,
+                    use_mlp_bias=False,
+                )
+                if self.vision_args.mm_projector_id == PATCH_MERGE
+                else None
+            )
+            self.vision_language_adapter = VisionLanguageAdapter(
+                self.vision_args, dim=config.text_config.hidden_size
+            )
+        else:
+            self.vision_encoder = None
+            self.pre_mm_projector_norm = None
+            self.patch_merger = None
+            self.vision_language_adapter = None
 
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors
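The constructor now follows a gate-and-None pattern: every vision submodule is either fully constructed or explicitly set to `None`, so downstream code can test for presence instead of re-deriving config flags. A distilled, self-contained sketch of the pattern (class and attribute names are illustrative, not vLLM API):

```python
import torch.nn as nn

class GatedMultiModalModel(nn.Module):
    """Builds optional submodules only when the modality is enabled."""

    def __init__(self, images_enabled: bool, hidden_size: int = 1024):
        super().__init__()
        if images_enabled:
            # Construct the whole vision stack together, so a partially
            # built tower can never be observed.
            self.vision_encoder = nn.Linear(hidden_size, hidden_size)
            self.adapter = nn.Linear(hidden_size, hidden_size)
        else:
            # Explicit None markers: presence checks replace config checks.
            self.vision_encoder = None
            self.adapter = None
```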
@@ -436,13 +445,17 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
         self,
         image_input: PixtralImagePixelInputs,
     ) -> tuple[torch.Tensor, ...]:
+        assert (
+            self.vision_encoder is not None and self.vision_language_adapter is not None
+        )
+
         images = image_input["images"]
         image_features = self.vision_encoder(images)
         feature_sizes = [image_feature.shape[0] for image_feature in image_features]
         image_features = torch.cat(image_features)
-        if self.vision_args.add_pre_mm_projector_layer_norm:
+        if self.pre_mm_projector_norm is not None:
             image_features = self.pre_mm_projector_norm(image_features)
-        if self.vision_args.mm_projector_id == PATCH_MERGE:
+        if self.patch_merger is not None:
             patch_size = self.vision_args.patch_size
             spatial_merge_size_square = self.vision_args.spatial_merge_size**2
             img_patch_dims = [
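Note the shift from config-flag checks (`add_pre_mm_projector_layer_norm`, `mm_projector_id == PATCH_MERGE`) to presence checks (`is not None`): once the vision stack can be skipped at construction time, the attributes themselves become the single source of truth. A minimal sketch of the guard idiom, under the assumption of optional `nn.Module` components (function and parameter names are illustrative):

```python
from typing import Optional

import torch
import torch.nn as nn

def project_features(
    features: torch.Tensor,
    norm: Optional[nn.Module],
    merger: Optional[nn.Module],
) -> torch.Tensor:
    # Presence checks mirror the construction-time gate, so a disabled
    # component is simply skipped rather than re-consulting config flags.
    if norm is not None:
        features = norm(features)
    if merger is not None:
        features = merger(features)
    return features
```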
@@ -508,41 +521,57 @@ class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP)
             return weight[0].startswith("pre_mm_projector_norm")
 
         # Get references to parameters for direct loading
-        vision_encoder_dict = dict(self.vision_encoder.named_parameters())
+        vision_encoder_dict = (
+            dict(self.vision_encoder.named_parameters())
+            if self.vision_encoder is not None
+            else {}
+        )
         patch_merger_dict = (
             dict(self.patch_merger.named_parameters())
-            if self.vision_args.mm_projector_id == PATCH_MERGE
-            else dict()
+            if self.patch_merger is not None
+            else {}
         )
         pre_mm_projector_norm_dict = (
             dict(self.pre_mm_projector_norm.named_parameters())
-            if self.vision_args.add_pre_mm_projector_layer_norm
-            else dict()
+            if self.pre_mm_projector_norm is not None
+            else {}
+        )
+        vision_lang_adapter_dict = (
+            dict(self.vision_language_adapter.named_parameters())
+            if self.vision_language_adapter is not None
+            else {}
         )
-        vision_lang_adapter_dict = dict(self.vision_language_adapter.named_parameters())
 
         def llm_weights_generator():
             # Single pass over weights
             for name, w in weights:
                 if is_vision_encoder_weights((name, w)):
+                    if self.vision_encoder is None:
+                        continue
                     # Load vision encoder weights directly
                     trimmed_name = ".".join(name.split(".")[1:])
                     param = vision_encoder_dict[trimmed_name]
                     with torch.no_grad():
                         default_weight_loader(param, w)
                 elif is_patch_merger((name, w)):
+                    if self.patch_merger is None:
+                        continue
                     # Load vision patch merger weights directly
                     trimmed_name = ".".join(name.split(".")[1:])
                     param = patch_merger_dict[trimmed_name]
                     with torch.no_grad():
                         default_weight_loader(param, w)
                 elif is_pre_mm_projector_norm((name, w)):
+                    if self.pre_mm_projector_norm is None:
+                        continue
                     # Load vision pre_mm_projector_norm weights directly
                     trimmed_name = ".".join(name.split(".")[1:])
                     param = pre_mm_projector_norm_dict[trimmed_name]
                     with torch.no_grad():
                         default_weight_loader(param, w)
                 elif is_vision_lang_adapter_weights((name, w)):
+                    if self.vision_language_adapter is None:
+                        continue
                     # Load vision-language adapter weights directly
                     trimmed_name = ".".join(name.split(".")[1:])
                     param = vision_lang_adapter_dict[trimmed_name]
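With the vision stack optional, weight loading must tolerate checkpoints that still contain vision tensors: the generator filters them out with `continue` rather than failing a lookup in an empty parameter dict, and only language-model weights flow through. A distilled sketch of the filter-and-forward pattern, assuming weight names carry module prefixes like those tested by the `is_*_weights` helpers above (the exact prefixes here are illustrative):

```python
from typing import Iterable

import torch

def filter_llm_weights(
    weights: Iterable[tuple[str, torch.Tensor]],
    vision_enabled: bool,
) -> Iterable[tuple[str, torch.Tensor]]:
    """Yield only the weights the constructed modules can accept."""
    for name, w in weights:
        # The checkpoint still ships vision tensors; drop them when the
        # corresponding modules were never built.
        if not vision_enabled and name.startswith(
            (
                "vision_encoder.",
                "patch_merger.",
                "pre_mm_projector_norm.",
                "vision_language_adapter.",
            )
        ):
            continue
        yield name, w
```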