[Doc] Update LLaVA docs (#5437)

Co-authored-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
Cyrus Leung
2024-06-14 02:22:07 +08:00
committed by GitHub
parent 39873476f8
commit 0ce7b952f8
3 changed files with 29 additions and 38 deletions

View File

@@ -227,7 +227,7 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
attn_metadata: AttentionMetadata,
**kwargs: object,
) -> SamplerOutput:
"""Run forward pass for Llava 1.5.
"""Run forward pass for LLaVA-1.5.
One key thing to understand is the `input_ids` already accounts for the
positions of the to-be-inserted image embeddings.
@@ -247,22 +247,25 @@ class LlavaForConditionalGeneration(VisionLanguageModelBase):
This way, the `positions` and `attn_metadata` are consistent
with the `input_ids`.
The model takes two types of image inputs:
PIXEL_VALUES and IMAGE_FEATURES.
The following shows how each maps to huggingface implementation.
PIXEL_VALUES:
- https://github.com/huggingface/transformers/blob/07bdbeb/src/transformers/models/llava/modeling_llava.py#L353
IMAGE_FEATURES:
- https://github.com/huggingface/transformers/blob/07bdbeb/src/transformers/models/llava/modeling_llava.py#L430
before going through the multi modal projector.
This model has two modes of image inputs:
`PIXEL_VALUES` and `IMAGE_FEATURES`.
Args:
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
pixel_values: For PIXEL_VALUES, expects a batch with shape
[1, 3, 336, 336].
image_features: For IMAGE_FEATURES, expects a batch with shape
[1, 576, 1024].
pixel_values: The pixels in each input image.
Expects a batch with shape `[1, 3, 336, 336]`.
(Only applicable to `PIXEL_VALUES` mode)
image_features: The image features for each input image outputted by
the vision tower before passing to the multi-modal projector.
Expects a batch with shape `[1, 576, 1024]`.
(Only applicable to `IMAGE_FEATURES` mode)
See also:
Each input maps to huggingface implementation, as follows:
- `pixel_values`: https://github.com/huggingface/transformers/blob/v4.41.1/src/transformers/models/llava/modeling_llava.py#L360
- `image_features`: https://github.com/huggingface/transformers/blob/v4.41.1/src/transformers/models/llava/modeling_llava.py#L437
"""
image_input = self._parse_and_validate_image_input(**kwargs)