[Model] Consolidate ViTs attention implementation without mask (#10893)

Author: Isotr0py <2037008807@qq.com>
Date: 2024-12-05 02:11:08 +08:00
Committed by: GitHub
Commit: 10398b4706 (parent 01d079fd8e)
Signed-off-by: Isotr0py <2037008807@qq.com>

9 changed files with 107 additions and 224 deletions


@@ -482,6 +482,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
         self.mlp1 = self._init_mlp1(config)
 
         self.img_context_token_id = None
+        self.visual_token_mask = None
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)
@@ -635,13 +636,12 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
         return image_embeds
 
-    def _get_visual_token_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def _set_visual_token_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
         if self.is_mono:
-            visual_token_mask = (
+            self.visual_token_mask = (
                 input_ids == self.img_context_token_id).reshape(-1, 1)
         else:
-            visual_token_mask = None
-        return visual_token_mask
+            self.visual_token_mask = None
 
     def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
         image_input = self._parse_and_validate_image_input(**kwargs)
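For context (not part of the diff), a minimal, self-contained sketch of the refactor above: the old _get_visual_token_mask returned a fresh mask that had to be threaded through forward, while the new _set_visual_token_mask caches it on the module so a later forward call can pick it up. The class and token id below are hypothetical stand-ins, not vLLM's real values.

import torch

IMG_CONTEXT_TOKEN_ID = 32000  # hypothetical placeholder id for <IMG_CONTEXT>

class MaskCachingSketch:
    """Stand-in for the InternVL mono-architecture mask logic."""

    def __init__(self) -> None:
        self.is_mono = True
        self.img_context_token_id = IMG_CONTEXT_TOKEN_ID
        self.visual_token_mask = None

    def _set_visual_token_mask(self, input_ids: torch.Tensor) -> None:
        if self.is_mono:
            # One boolean per token, shaped (num_tokens, 1): True exactly
            # where input_ids holds the image placeholder token.
            self.visual_token_mask = (
                input_ids == self.img_context_token_id).reshape(-1, 1)
        else:
            self.visual_token_mask = None

m = MaskCachingSketch()
m._set_visual_token_mask(torch.tensor([1, 32000, 32000, 2]))
print(m.visual_token_mask.squeeze(1))  # tensor([False,  True,  True, False])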
@@ -658,6 +658,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
         inputs_embeds = self.language_model.get_input_embeddings(input_ids)
         if multimodal_embeddings is not None:
             assert self.img_context_token_id is not None
+            self._set_visual_token_mask(input_ids)
             inputs_embeds = merge_multimodal_embeddings(
                 input_ids, inputs_embeds, multimodal_embeddings,
                 self.img_context_token_id)
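The new call site places mask computation in the same step where the image embeddings are merged into the text embeddings, so forward no longer needs access to encoder output. A rough stand-in for merge_multimodal_embeddings (the real vLLM helper does more validation; sizes below are toy values) to show the flow:

import torch

IMG_CONTEXT_TOKEN_ID = 32000  # hypothetical placeholder id

def merge_sketch(input_ids: torch.Tensor, inputs_embeds: torch.Tensor,
                 image_embeds: torch.Tensor,
                 placeholder_id: int) -> torch.Tensor:
    # Scatter vision-encoder outputs into the rows occupied by the
    # placeholder tokens; all other rows keep their text embeddings.
    out = inputs_embeds.clone()
    out[input_ids == placeholder_id] = image_embeds
    return out

input_ids = torch.tensor([1, IMG_CONTEXT_TOKEN_ID, IMG_CONTEXT_TOKEN_ID, 2])
inputs_embeds = torch.zeros(4, 8)  # text embeddings, one row per token
image_embeds = torch.ones(2, 8)    # one row per placeholder token
inputs_embeds = merge_sketch(input_ids, inputs_embeds, image_embeds,
                             IMG_CONTEXT_TOKEN_ID)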
@@ -674,7 +675,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
         **kwargs: object,
     ) -> Union[SamplerOutput, IntermediateTensors]:
 
-        visual_token_mask = None
         if intermediate_tensors is not None:
             input_ids = None
             inputs_embeds = None
@@ -695,16 +695,15 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
             "intermediate_tensors": intermediate_tensors,
             "inputs_embeds": inputs_embeds,
         }
 
-        if self.img_context_token_id is not None:
-            visual_token_mask = self._get_visual_token_mask(input_ids)
-            # We always overwrite it back to None after computing visual token
-            # mask so that this doesn't need to depend on encoder output
+        if self.visual_token_mask is not None:
+            # overwrite visual_token_mask and img_context_token_id back to None,
+            # so that this doesn't need to depend on encoder output
+            forward_kwargs.update(
+                {"visual_token_mask": self.visual_token_mask})
+            self.visual_token_mask = None
             self.img_context_token_id = None
-
-        if self.is_mono:
-            forward_kwargs.update({"visual_token_mask": visual_token_mask})
 
         hidden_states = self.language_model.model(**forward_kwargs)
         return hidden_states
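A sketch of the consume-and-reset pattern the last hunk introduces (simplified; the model argument and helper name below are stand-ins, not vLLM API): the cached mask is handed to the language model at most once, then both fields are cleared so a text-only forward never sees stale state from a previous multimodal request.

def consume_visual_token_mask(model, forward_kwargs: dict) -> dict:
    # If get_input_embeddings cached a mask earlier in this step, pass it
    # through to the language model ...
    if model.visual_token_mask is not None:
        forward_kwargs.update(
            {"visual_token_mask": model.visual_token_mask})
        # ... then clear both fields so this code path never depends on
        # whether the vision encoder ran in this step.
        model.visual_token_mask = None
        model.img_context_token_id = None
    return forward_kwargs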