[VLM] Minor space optimization for ClipVisionModel (#6436)
@@ -80,13 +80,11 @@ class Phi3ImageEmbeddingBase(nn.Module):
 
     def get_img_features(self,
                          img_embeds: torch.FloatTensor) -> torch.FloatTensor:
-        LAYER_IDX = self.layer_idx
         TYPE_FEATURE = self.type_feature
 
         # NOTE: we skip the step to select the vision feature layer since
         # this is already done inside the img_processor
-        img_feature = self.img_processor(img_embeds,
-                                         vision_feature_layer=LAYER_IDX)
+        img_feature = self.img_processor(img_embeds)
 
         if TYPE_FEATURE == "patch":
             patch_feature = img_feature[:, 1:]
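For context on why the vision_feature_layer argument can simply be dropped: the next hunk constructs the CLIP tower only up to the requested feature layer, so the truncated encoder's final hidden state is exactly the tensor that used to be selected by index. A minimal runnable sketch of that invariant (stand-in lambdas, not vLLM code):

    # Each "layer" here is a stand-in; real encoder layers are nn.Modules.
    layers = [lambda x, i=i: x + i for i in range(24)]

    def hidden_states(x, layers):
        # Mirrors the HF convention: element 0 is the pre-encoder input.
        states = [x]
        for layer in layers:
            x = layer(x)
            states.append(x)
        return states

    layer_idx = -2
    full = hidden_states(0, layers)
    truncated = hidden_states(0, layers[:len(layers) + layer_idx + 1])
    assert truncated[-1] == full[layer_idx]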
@@ -111,7 +109,17 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
             config, 'n_embd') else config.hidden_size
 
         clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG
-        self.img_processor = CLIPVisionModel(clip_config)
+        self.layer_idx = config.img_processor.get('layer_idx', -2)
+
+        # Initialize the CLIP only up to the required feature layer
+        if self.layer_idx < 0:
+            num_hidden_layers = clip_config.num_hidden_layers + \
+                self.layer_idx + 1
+        else:
+            num_hidden_layers = self.layer_idx + 1
+
+        self.img_processor = CLIPVisionModel(
+            clip_config, num_hidden_layers_override=num_hidden_layers)
         image_dim_out = config.img_processor['image_dim_out']
         self.num_img_tokens = config.img_processor['num_img_tokens']
 
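A worked example of the layer-count arithmetic for the default configuration (assuming the standard 24-layer depth of CLIP ViT-L/14-336):

    # layer_idx = -2 selects the second-to-last hidden state, counting
    # the embedding output as hidden state 0, so only
    # 24 + (-2) + 1 = 23 encoder layers are needed; the final layer's
    # weights are never allocated at all.
    num_hidden_layers = 24 + (-2) + 1
    assert num_hidden_layers == 23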
@@ -142,8 +150,6 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
         self.img_projection = nn.Sequential(*layers)
 
         self.vocab_size = config.vocab_size
-
-        self.layer_idx = config.img_processor.get('layer_idx', -2)
         self.type_feature = config.img_processor.get('type_feature', 'patch')
 
     def forward(self, input_ids: torch.LongTensor,
@@ -588,7 +594,8 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
-                weight_loader(param, loaded_weight)
+                if name in params_dict:
+                    param = params_dict[name]
+                    weight_loader = getattr(param, "weight_loader",
+                                            default_weight_loader)
+                    weight_loader(param, loaded_weight)
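The new "if name in params_dict" guard is what makes the truncation safe at load time: the checkpoint still contains weights for the encoder layers that are no longer instantiated, and those names no longer resolve to any parameter. A hypothetical illustration (the key names below are made up, shown only to illustrate the shape of a skipped name):

    # Weights for pruned layers now fall through instead of raising
    # KeyError when indexing params_dict.
    params_dict = {"encoder.layers.22.mlp.fc1.weight": "param"}
    checkpoint = [("encoder.layers.22.mlp.fc1.weight", "w22"),
                  ("encoder.layers.23.mlp.fc1.weight", "w23")]  # pruned
    for name, loaded_weight in checkpoint:
        if name in params_dict:
            print("loading", name)  # only the layer-22 weight is loaded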