[Core][VLM] Support image embeddings as input (#6613)

This commit is contained in:
Roger Wang
2024-08-12 01:16:06 -07:00
committed by GitHub
parent ec2affa8ae
commit e6e42e4b17
13 changed files with 517 additions and 138 deletions

View File

@@ -70,6 +70,36 @@ CLIP_VIT_LARGE_PATCH14_336_CONFIG = CLIPVisionConfig(dropout=0.0,
projection_dim=768)
class Phi3VImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: Union[torch.Tensor, List[torch.Tensor]]
"""
Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
Note that `num_patches` may be different for each batch, in which case
the data is passed as a list instead of a batched tensor.
"""
image_sizes: torch.Tensor
"""
Shape: `(batch_size, 2)`
This should be in `(height, width)` format.
"""
class Phi3VImageEmbeddingInputs(TypedDict):
type: Literal["image_embeds"]
data: Union[torch.Tensor, List[torch.Tensor]]
"""Shape: `(batch_size, image_feature_size, hidden_size)`
`hidden_size` must match the hidden size of language model backbone.
"""
Phi3VImageInputs = Union[Phi3VImagePixelInputs, Phi3VImageEmbeddingInputs]
class Phi3ImageEmbeddingBase(nn.Module):
def __init__(self) -> None:
@@ -257,24 +287,6 @@ class Phi3HDImageEmbedding(Phi3ImageEmbeddingBase):
return image_features_hd_newline
class Phi3VImagePixelInputs(TypedDict):
type: Literal["pixel_values"]
data: Union[torch.Tensor, List[torch.Tensor]]
"""
Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
Note that `num_patches` may be different for each batch, in which case
the data is passed as a list instead of a batched tensor.
"""
image_sizes: torch.Tensor
"""
Shape: `(batch_size, 2)`
This should be in `(height, width)` format.
"""
# Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L57
def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336):
target_height = int(np.ceil(height / padding_unit) * padding_unit)
@@ -390,7 +402,7 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs):
input_width=w,
input_height=h)
elif isinstance(image_data, torch.Tensor):
raise NotImplementedError("Embeddings input is not supported yet")
image_feature_size = image_data.shape[0]
else:
raise TypeError(f"Invalid image type: {type(image_data)}")
@@ -494,25 +506,55 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
return data
def _parse_and_validate_image_input(
self, **kwargs: object) -> Optional[Phi3VImagePixelInputs]:
self, **kwargs: object) -> Optional[Phi3VImageInputs]:
pixel_values = kwargs.pop("pixel_values", None)
image_sizes = kwargs.pop("image_sizes", None)
image_embeds = kwargs.pop("image_embeds", None)
if pixel_values is None:
return None
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
if pixel_values is None and image_embeds is None:
return None
if not isinstance(image_sizes, torch.Tensor):
raise ValueError("Incorrect type of image sizes. "
f"Got type: {type(image_sizes)}")
if pixel_values is not None:
if not isinstance(pixel_values, (torch.Tensor, list)):
raise ValueError("Incorrect type of pixel values. "
f"Got type: {type(pixel_values)}")
return Phi3VImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
image_sizes=self._validate_image_sizes(image_sizes))
if not isinstance(image_sizes, torch.Tensor):
raise ValueError("Incorrect type of image sizes. "
f"Got type: {type(image_sizes)}")
return Phi3VImagePixelInputs(
type="pixel_values",
data=self._validate_pixel_values(pixel_values),
image_sizes=self._validate_image_sizes(image_sizes))
if image_embeds is not None:
if not isinstance(image_embeds, torch.Tensor):
raise ValueError("Incorrect type of image embeddings. "
f"Got type: {type(image_embeds)}")
return Phi3VImageEmbeddingInputs(
type="image_embeds",
data=image_embeds,
)
raise AssertionError("This line should be unreachable.")
def _process_image_input(
self,
image_input: Phi3VImageInputs,
) -> torch.Tensor:
if image_input["type"] == "image_embeds":
return image_input["data"]
assert self.vision_embed_tokens is not None
image_embeds = self.vision_embed_tokens(image_input["data"],
image_input["image_sizes"])
return image_embeds
def forward(self,
input_ids: torch.Tensor,
@@ -524,8 +566,7 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
vision_embeddings = self.vision_embed_tokens(
image_input["data"], image_input["image_sizes"])
vision_embeddings = self._process_image_input(image_input)
inputs_embeds = self.model.get_input_embeddings(input_ids)
inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
vision_embeddings,