[VLM] Minor space optimization for CLIPVisionModel (#6436)

This commit is contained in:
Roger Wang
2024-07-15 02:29:51 -07:00
committed by GitHub
parent 22e79ee8f3
commit 6ae1597ddf
4 changed files with 66 additions and 39 deletions

View File

@@ -128,8 +128,17 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
self.config = config
self.multimodal_config = multimodal_config
# Initialize the vision tower only up to the required feature layer
vision_feature_layer = config.vision_feature_layer
if vision_feature_layer < 0:
num_hidden_layers = config.vision_config.num_hidden_layers \
+ vision_feature_layer + 1
else:
num_hidden_layers = vision_feature_layer + 1
# TODO: Optionally initialize this to support embeddings.
self.vision_tower = CLIPVisionModel(config.vision_config)
self.vision_tower = CLIPVisionModel(
config.vision_config, num_hidden_layers_override=num_hidden_layers)
self.multi_modal_projector = LlavaMultiModalProjector(
vision_hidden_size=config.vision_config.hidden_size,
text_hidden_size=config.text_config.hidden_size,
@@ -193,8 +202,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
# NOTE: we skip the step to select the vision feature layer since
# this is already done inside the vision tower
image_features = vision_tower(pixel_values,
self.config.vision_feature_layer)
image_features = vision_tower(pixel_values)
return self._select_image_features(
image_features,
@@ -333,7 +341,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
break
else:
use_default_weight_loading = True
if use_default_weight_loading:
if use_default_weight_loading and name in params_dict:
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)