[Core][VLM] Support image embeddings as input (#6613)
This commit is contained in:
@@ -234,7 +234,8 @@ class FuyuForCausalLM(nn.Module, SupportsVision):
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config)
|
||||
|
||||
def _parse_and_validate_image_input(self, **kwargs: object):
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object) -> Optional[FuyuImagePixelInputs]:
|
||||
image_patches = kwargs.pop("image_patches", None)
|
||||
|
||||
if isinstance(image_patches, torch.Tensor):
|
||||
@@ -249,6 +250,13 @@ class FuyuForCausalLM(nn.Module, SupportsVision):
|
||||
data=image_patches)
|
||||
return None
|
||||
|
||||
def _process_image_input(
|
||||
self, image_input: FuyuImagePixelInputs) -> torch.Tensor:
|
||||
|
||||
assert self.vision_embed_tokens is not None
|
||||
vision_embeddings, _ = self.vision_embed_tokens(image_input["data"])
|
||||
return vision_embeddings
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
@@ -261,8 +269,7 @@ class FuyuForCausalLM(nn.Module, SupportsVision):
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
|
||||
if image_input is not None:
|
||||
vision_embeddings, _ = self.vision_embed_tokens(
|
||||
image_input["data"])
|
||||
vision_embeddings = self._process_image_input(image_input)
|
||||
inputs_embeds = self.language_model.model.embed_tokens(input_ids)
|
||||
inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
|
||||
vision_embeddings,
|
||||
|
||||
Reference in New Issue
Block a user