[Core][VLM] Support image embeddings as input (#6613)
This commit is contained in:
@@ -27,6 +27,24 @@ from .utils import (filter_weights, init_vllm_registered_model,
|
||||
merge_vision_embeddings)
|
||||
|
||||
|
||||
class LlavaImagePixelInputs(TypedDict):
    """Image input supplied as raw pixel values.

    The ``type`` field is the discriminator that tells this variant apart
    from other image-input dictionaries.
    """

    type: Literal["pixel_values"]
    data: torch.Tensor
    """Shape: `(batch_size, num_channels, height, width)`"""
|
||||
|
||||
|
||||
class LlavaImageEmbeddingInputs(TypedDict):
    """Image input supplied as already-computed image embeddings.

    The ``type`` field is the discriminator that tells this variant apart
    from other image-input dictionaries.
    """

    type: Literal["image_embeds"]
    data: torch.Tensor
    """Shape: `(batch_size, image_feature_size, hidden_size)`

    `hidden_size` must match the hidden size of language model backbone.
    """
|
||||
|
||||
|
||||
# An image input is either raw pixel values or image embeddings; consumers
# tell the two apart through the "type" key of the dictionary.
LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs]
|
||||
|
||||
|
||||
# TODO(xwjiang): Run benchmark and decide whether to use tensor parallelism (TP).
|
||||
class LlavaMultiModalProjector(nn.Module):
|
||||
|
||||
@@ -49,15 +67,6 @@ class LlavaMultiModalProjector(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
# NOTE(review): these lines appear to be the removed (pre-change) side of the
# `-49,15 +67,6` diff hunk rather than live code -- confirm against the
# repository. If they were executed after the earlier definitions, the alias
# at the end would rebind `LlavaImageInputs` to the pixel-only type and drop
# the embeddings variant from the union declared earlier.
class LlavaImagePixelInputs(TypedDict):
|
||||
type: Literal["pixel_values"]
|
||||
data: torch.Tensor
|
||||
"""Shape: `(batch_size, num_channels, height, width)`"""
|
||||
|
||||
|
||||
LlavaImageInputs = LlavaImagePixelInputs
|
||||
|
||||
|
||||
def get_max_llava_image_tokens(ctx: InputContext):
|
||||
hf_config = ctx.get_hf_config(LlavaConfig)
|
||||
vision_config = hf_config.vision_config
|
||||
@@ -210,18 +219,30 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
|
||||
def _parse_and_validate_image_input(
        self, **kwargs: object) -> Optional[LlavaImageInputs]:
    """Extract the image input, if any, from the keyword arguments.

    Looks up ``pixel_values`` and ``image_embeds`` in ``kwargs``. When both
    are supplied, ``pixel_values`` takes precedence and ``image_embeds`` is
    ignored.

    Returns:
        ``None`` when neither value is supplied; otherwise a
        ``LlavaImagePixelInputs`` or ``LlavaImageEmbeddingInputs``
        dictionary wrapping the supplied tensor.

    Raises:
        ValueError: If the supplied value is not a ``torch.Tensor``.
    """
    pixel_values = kwargs.pop("pixel_values", None)
    image_embeds = kwargs.pop("image_embeds", None)

    # Text-only request: nothing to parse.
    if pixel_values is None and image_embeds is None:
        return None

    if pixel_values is not None:
        if not isinstance(pixel_values, torch.Tensor):
            raise ValueError("Incorrect type of pixel values. "
                             f"Got type: {type(pixel_values)}")

        # Pixel values get an extra check via `_validate_pixel_values`.
        return LlavaImagePixelInputs(
            type="pixel_values",
            data=self._validate_pixel_values(pixel_values),
        )

    if image_embeds is not None:
        if not isinstance(image_embeds, torch.Tensor):
            raise ValueError("Incorrect type of image embeddings. "
                             f"Got type: {type(image_embeds)}")

        # Embeddings are only type-checked here; they are passed on as-is.
        return LlavaImageEmbeddingInputs(
            type="image_embeds",
            data=image_embeds,
        )

    # Both-None was handled above, so one of the branches must have returned.
    raise AssertionError("This line should be unreachable.")
|
||||
|
||||
def _select_image_features(self, image_features: torch.Tensor, *,
|
||||
strategy: str) -> torch.Tensor:
|
||||
@@ -258,6 +279,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
|
||||
|
||||
def _process_image_input(self,
|
||||
image_input: LlavaImageInputs) -> torch.Tensor:
|
||||
|
||||
if image_input["type"] == "image_embeds":
|
||||
return image_input["data"]
|
||||
|
||||
assert self.vision_tower is not None
|
||||
image_features = self._process_image_pixels(image_input)
|
||||
return self.multi_modal_projector(image_features)
|
||||
|
||||
Reference in New Issue
Block a user