[Core][VLM] Support image embeddings as input (#6613)

This commit is contained in:
Roger Wang
2024-08-12 01:16:06 -07:00
committed by GitHub
parent ec2affa8ae
commit e6e42e4b17
13 changed files with 517 additions and 138 deletions

View File

@@ -234,7 +234,8 @@ class FuyuForCausalLM(nn.Module, SupportsVision):
cache_config=cache_config,
quant_config=quant_config)
def _parse_and_validate_image_input(self, **kwargs: object):
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[FuyuImagePixelInputs]:
image_patches = kwargs.pop("image_patches", None)
if isinstance(image_patches, torch.Tensor):
@@ -249,6 +250,13 @@ class FuyuForCausalLM(nn.Module, SupportsVision):
data=image_patches)
return None
def _process_image_input(
self, image_input: FuyuImagePixelInputs) -> torch.Tensor:
assert self.vision_embed_tokens is not None
vision_embeddings, _ = self.vision_embed_tokens(image_input["data"])
return vision_embeddings
def forward(
self,
input_ids: torch.Tensor,
@@ -261,8 +269,7 @@ class FuyuForCausalLM(nn.Module, SupportsVision):
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
vision_embeddings, _ = self.vision_embed_tokens(
image_input["data"])
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.language_model.model.embed_tokens(input_ids)
inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
vision_embeddings,