[Model][LoRA] LoRA support added for MiniCPMV2.6 (#8943)

Co-authored-by: DarkLight1337 <tlleungac@connect.ust.hk>
Author: Jee Jee Li
Date: 2024-09-30 12:31:55 +08:00
Committed by: GitHub
Parent: b6d7392579
Commit: 8e60afa15e
3 changed files with 49 additions and 880 deletions

vllm/model_executor/models/idefics2_vision_model.py

```diff
@@ -65,11 +65,10 @@ class Idefics2VisionEmbeddings(nn.Module):
         self.position_embedding = nn.Embedding(self.num_positions,
                                                self.embed_dim)
 
-    def forward(
-        self,
-        pixel_values: torch.FloatTensor,
-        patch_attention_mask: torch.BoolTensor,
-    ) -> torch.Tensor:
+    def forward(self,
+                pixel_values: torch.FloatTensor,
+                patch_attention_mask: torch.BoolTensor,
+                tgt_sizes: Optional[torch.IntTensor] = None) -> torch.Tensor:
         batch_size, _, max_im_h, max_im_w = pixel_values.shape
         patch_embeds = self.patch_embedding(pixel_values)
         embeddings = patch_embeds.flatten(2).transpose(1, 2)
```
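For reference, a minimal sketch of the shape the new `tgt_sizes` argument is assumed to have: one `(patches_h, patches_w)` row per image, letting the embedding layer take each image's patch grid directly rather than inferring it from the attention mask. The concrete values below are made up for illustration:

```python
import torch

# Hypothetical tgt_sizes for a batch of two images whose patch grids
# are 12x16 and 20x9; dtype matches the torch.IntTensor annotation.
tgt_sizes = torch.tensor([[12, 16], [20, 9]], dtype=torch.int32)

# The new code path reads the grid directly instead of recovering it
# from patch_attention_mask row/column sums.
nb_patches_h = tgt_sizes[0][0]  # tensor(12, dtype=torch.int32)
nb_patches_w = tgt_sizes[0][1]  # tensor(16, dtype=torch.int32)
```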
```diff
@@ -84,8 +83,13 @@ class Idefics2VisionEmbeddings(nn.Module):
                                   fill_value=0)
 
         for batch_idx, p_attn_mask in enumerate(patch_attention_mask):
-            nb_patches_h = p_attn_mask[:, 0].sum()
-            nb_patches_w = p_attn_mask[0].sum()
+
+            if tgt_sizes is not None:
+                nb_patches_h = tgt_sizes[batch_idx][0]
+                nb_patches_w = tgt_sizes[batch_idx][1]
+            else:
+                nb_patches_h = p_attn_mask[:, 0].sum()
+                nb_patches_w = p_attn_mask[0].sum()
             fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
             fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)
             bucket_coords_h = torch.bucketize(fractional_coords_h,
```
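As a sanity check on the hunk above, here is a standalone sketch of the position-id bucketing, assuming `num_patches_per_side` and the `boundaries` tensor mirror the Idefics2 embedding setup; the `(12, 16)` grid stands in for one `tgt_sizes` entry:

```python
import torch

num_patches_per_side = 27  # assumption: full-resolution grid side
boundaries = torch.arange(1 / num_patches_per_side, 1.0,
                          1 / num_patches_per_side)

nb_patches_h, nb_patches_w = 12, 16  # e.g. one row of tgt_sizes
fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h)
fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w)

# Map each fractional coordinate onto the nearest cell of the full
# num_patches_per_side x num_patches_per_side position-embedding grid.
bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True)
bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True)
pos_ids = (bucket_coords_h[:, None] * num_patches_per_side +
           bucket_coords_w).flatten()

print(pos_ids.shape)  # torch.Size([192]) == 12 * 16
```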
```diff
@@ -287,10 +291,12 @@ class Idefics2VisionTransformer(nn.Module):
         self,
         pixel_values,
         patch_attention_mask: Optional[torch.BoolTensor] = None,
-    ) -> torch.tensor:
+        tgt_sizes: Optional[torch.IntTensor] = None,
+    ) -> torch.Tensor:
         hidden_states = self.embeddings(
             pixel_values=pixel_values,
-            patch_attention_mask=patch_attention_mask)
+            patch_attention_mask=patch_attention_mask,
+            tgt_sizes=tgt_sizes)
         encoder_outputs = self.encoder(hidden_states)
         last_hidden_state = self.post_layernorm(encoder_outputs)
         return last_hidden_state
```
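Finally, a hypothetical end-to-end sketch of what this commit enables: running MiniCPM-V-2.6 with a LoRA adapter through vLLM's offline `LLM` API. The adapter path, rank, and prompt are placeholders, not taken from this commit or its tests.

```python
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# enable_lora activates the LoRA code path; max_lora_rank is an
# assumption and must cover the adapter's actual rank.
llm = LLM(
    model="openbmb/MiniCPM-V-2_6",
    trust_remote_code=True,
    enable_lora=True,
    max_lora_rank=64,
)

outputs = llm.generate(
    "Describe the picture in one sentence.",
    SamplingParams(temperature=0.0, max_tokens=64),
    # Hypothetical adapter: (name, unique int id, local path).
    lora_request=LoRARequest("minicpmv_lora", 1, "/path/to/adapter"),
)
print(outputs[0].outputs[0].text)
```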