[Model] Consolidate ViTs attention implementation without mask (#10893)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
@@ -482,6 +482,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
self.mlp1 = self._init_mlp1(config)
|
||||
|
||||
self.img_context_token_id = None
|
||||
self.visual_token_mask = None
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.language_model.make_empty_intermediate_tensors)
|
||||
|
||||
@@ -635,13 +636,12 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
|
||||
return image_embeds
|
||||
|
||||
def _get_visual_token_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
def _set_visual_token_mask(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
if self.is_mono:
|
||||
visual_token_mask = (
|
||||
self.visual_token_mask = (
|
||||
input_ids == self.img_context_token_id).reshape(-1, 1)
|
||||
else:
|
||||
visual_token_mask = None
|
||||
return visual_token_mask
|
||||
self.visual_token_mask = None
|
||||
|
||||
def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
@@ -658,6 +658,7 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
if multimodal_embeddings is not None:
|
||||
assert self.img_context_token_id is not None
|
||||
self._set_visual_token_mask(input_ids)
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids, inputs_embeds, multimodal_embeddings,
|
||||
self.img_context_token_id)
|
||||
@@ -674,7 +675,6 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
**kwargs: object,
|
||||
) -> Union[SamplerOutput, IntermediateTensors]:
|
||||
|
||||
visual_token_mask = None
|
||||
if intermediate_tensors is not None:
|
||||
input_ids = None
|
||||
inputs_embeds = None
|
||||
@@ -695,16 +695,15 @@ class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP):
|
||||
"intermediate_tensors": intermediate_tensors,
|
||||
"inputs_embeds": inputs_embeds,
|
||||
}
|
||||
if self.img_context_token_id is not None:
|
||||
visual_token_mask = self._get_visual_token_mask(input_ids)
|
||||
|
||||
# We always overwrite it back to None after computing visual token
|
||||
# mask so that this doesn't need to depend on encoder output
|
||||
if self.visual_token_mask is not None:
|
||||
# overwrite visual_token_mask and img_context_token_id back to None,
|
||||
# so that this doesn't need to depend on encoder output
|
||||
forward_kwargs.update(
|
||||
{"visual_token_mask": self.visual_token_mask})
|
||||
self.visual_token_mask = None
|
||||
self.img_context_token_id = None
|
||||
|
||||
if self.is_mono:
|
||||
forward_kwargs.update({"visual_token_mask": visual_token_mask})
|
||||
|
||||
hidden_states = self.language_model.model(**forward_kwargs)
|
||||
return hidden_states
|
||||
|
||||
|
||||
Reference in New Issue
Block a user