[VLM] Cleanup validation and update docs (#6149)

This commit is contained in:
Cyrus Leung
2024-07-05 13:49:38 +08:00
committed by GitHub
parent a41357e941
commit ea4b570483
3 changed files with 86 additions and 81 deletions

View File

@@ -149,14 +149,16 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
config.vocab_size, logit_scale)
self.sampler = Sampler()
def _validate_image_data(self, data: torch.Tensor) -> torch.Tensor:
if list(data.shape)[1:] != [
3, self.config.vision_config.image_size,
self.config.vision_config.image_size
]:
def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor:
h = w = self.config.vision_config.image_size
expected_dims = (3, h, w)
actual_dims = tuple(data.shape[1:])
if actual_dims != expected_dims:
expected_expr = ("batch_size", *map(str, expected_dims))
raise ValueError(
"The expected image tensor shape is batch dimension plus "
"channel, height and width.")
f"The expected shape of pixel values is {expected_expr}. "
f"You supplied {tuple(data.shape)}.")
return data
@@ -173,7 +175,7 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
return LlavaImagePixelInputs(
type="pixel_values",
data=self._validate_image_data(pixel_values),
data=self._validate_pixel_values(pixel_values),
)
def _select_image_features(self, image_features: torch.Tensor, *,
@@ -226,18 +228,25 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
One key thing to understand is the `input_ids` already accounts for the
positions of the to-be-inserted image embeddings.
Concretely, consider a text prompt:
"<image>\nUSER: What's the content of the image?\nASSISTANT:".
`"USER: <image>\\nWhat's the content of the image?\\nASSISTANT:"`.
Tokenizer outputs:
[1, 32000, 29871, 13, 11889, 29901, 1724, 29915, 29879, 278,
2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901].
The to-be-inserted image has a size of 576 (24 * 24) along the context
length dimension.
`input_ids` is thus [1, 32000, ..., 32000, 29871, 13, 11889, 29901,
1724, 29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933,
9047, 13566, 29901].
There will be 576 `32000` in the `input_ids`.
(32000 is the token id for `<image>`.)
`[1, 3148, 1001, 29901, 29871, 32000, 29871, 13, 5618, 29915, 29879,
278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566, 29901]`.
To reserve space in KV cache, we have to insert placeholder tokens
before they are inputted to the model, so the input processor prepends
additional image tokens (denoted as `32000`), resulting in:
`[1, 3148, 1001, 29901, 29871, 32000, ..., 32000, 29871, 13, 5618,
29915, 29879, 278, 2793, 310, 278, 1967, 29973, 13, 22933, 9047, 13566,
29901]`.
We insert 575 tokens so that including the original image token in the
input, there are a total of 576 (24 * 24) image tokens, which
corresponds to the number of image tokens inputted to the language
model, i.e. the number of image tokens outputted by the visual encoder.
This way, the `positions` and `attn_metadata` are consistent
with the `input_ids`.
@@ -246,6 +255,9 @@ class LlavaForConditionalGeneration(nn.Module, SupportsVision):
input_ids: Flattened (concatenated) input_ids corresponding to a
batch.
pixel_values: The pixels in each input image.
See also:
:class:`LlavaImageInputs`
"""
image_input = self._parse_and_validate_image_input(**kwargs)