[Core][VLM] Support image embeddings as input (#6613)

This commit is contained in:
Roger Wang
2024-08-12 01:16:06 -07:00
committed by GitHub
parent ec2affa8ae
commit e6e42e4b17
13 changed files with 517 additions and 138 deletions

View File

@@ -27,6 +27,24 @@ from .utils import (filter_weights, init_vllm_registered_model,
merge_vision_embeddings)
class LlavaImagePixelInputs(TypedDict):
    # Discriminator tag: "pixel_values" marks raw image pixels that still
    # need to go through the vision tower and projector.
    type: Literal["pixel_values"]
    data: torch.Tensor
    """Shape: `(batch_size, num_channels, height, width)`"""
class LlavaImageEmbeddingInputs(TypedDict):
    # Discriminator tag: "image_embeds" marks embeddings that are already
    # projected and are passed to the language model as-is.
    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """
# Tagged union over the two accepted image-input forms, discriminated by
# the "type" key: raw pixels (run through the vision tower) or
# pre-computed embeddings (returned unchanged by _process_image_input).
LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]
# TODO(xwjiang): Run benchmark and decide if TP.
class LlavaMultiModalProjector(nn.Module):
@@ -49,15 +67,6 @@ class LlavaMultiModalProjector(nn.Module):
return hidden_states
# NOTE(review): this duplicates the LlavaImagePixelInputs defined earlier in
# this page — it appears to be the pre-commit definition (plus the old
# non-union alias) shown as the *removed* side of the diff hunk; the +/-
# diff markers were lost in rendering. Do not treat it as live code.
class LlavaImagePixelInputs(TypedDict):
    type: Literal["pixel_values"]
    data: torch.Tensor
    """Shape: `(batch_size, num_channels, height, width)`"""


LlavaImageInputs = LlavaImagePixelInputs
def get_max_llava_image_tokens(ctx: InputContext):
hf_config = ctx.get_hf_config(LlavaConfig)
vision_config = hf_config.vision_config
@@ -210,18 +219,30 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
def _parse_and_validate_image_input(
        self, **kwargs: object) -> Optional[LlavaImageInputs]:
    """Extract and validate image inputs from multimodal ``kwargs``.

    Accepts either raw pixel values (``pixel_values``) or pre-computed
    image embeddings (``image_embeds``); pixel values take precedence
    when both are present.

    Returns:
        A tagged ``LlavaImageInputs`` dict, or ``None`` when no image
        input was supplied.

    Raises:
        ValueError: if the supplied input is not a ``torch.Tensor``.
    """
    # NOTE: the rendered diff interleaved the pre- and post-commit
    # branches (duplicate `if pixel_values is None` checks and a
    # duplicate pixel-values return); this is the coherent post-commit
    # version with the stale removed lines dropped.
    pixel_values = kwargs.pop("pixel_values", None)
    image_embeds = kwargs.pop("image_embeds", None)

    if pixel_values is None and image_embeds is None:
        return None

    if pixel_values is not None:
        if not isinstance(pixel_values, torch.Tensor):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")
        return LlavaImagePixelInputs(
            type="pixel_values",
            data=self._validate_pixel_values(pixel_values),
        )

    if image_embeds is not None:
        if not isinstance(image_embeds, torch.Tensor):
            raise ValueError("Incorrect type of image embeddings. "
                             f"Got type: {type(image_embeds)}")
        return LlavaImageEmbeddingInputs(
            type="image_embeds",
            data=image_embeds,
        )

    raise AssertionError("This line should be unreachable.")
def _select_image_features(self, image_features: torch.Tensor, *,
strategy: str) -> torch.Tensor:
@@ -258,6 +279,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
def _process_image_input(self,
image_input: LlavaImageInputs) -> torch.Tensor:
if image_input["type"] == "image_embeds":
return image_input["data"]
assert self.vision_tower is not None
image_features = self._process_image_pixels(image_input)
return self.multi_modal_projector(image_features)