[VLM] Cleanup validation and update docs (#6149)

This commit is contained in:
Cyrus Leung
2024-07-05 13:49:38 +08:00
committed by GitHub
parent a41357e941
commit ea4b570483
3 changed files with 86 additions and 81 deletions

View File

@@ -263,7 +263,8 @@ class Phi3VImagePixelInputs(TypedDict):
"""
Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
Note that `num_patches` may be different for each batch.
Note that `num_patches` may be different for each batch, in which case
the data is passed as a list instead of a batched tensor.
"""
image_sizes: torch.Tensor
@@ -466,8 +467,8 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
if list(data.shape[1:]) != [2]:
raise ValueError(
f"The expected image sizes shape is batch dimension plus "
f"{[2]}. You supplied {data.shape}.")
f"The expected shape of image sizes is batch dimension plus "
f"{[2]}. You supplied {tuple(data.shape)}.")
return data
@@ -475,19 +476,20 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
self, data: Union[torch.Tensor, List[torch.Tensor]]
) -> Union[torch.Tensor, List[torch.Tensor]]:
def _validate_shape(data: torch.Tensor):
if list(data.shape)[2:] != [
3, CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size,
CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size
]:
raise ValueError(
"The expected pixel value tensor shape is batch dimension "
"plus patch number, channel, height and width.")
h = w = CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size
expected_dims = (3, h, w)
if isinstance(data, torch.Tensor):
_validate_shape(data)
else:
[_validate_shape(d) for d in data]
def _validate_shape(d: torch.Tensor):
actual_dims = tuple(d.shape[1:])
if actual_dims != expected_dims:
expected_expr = ("num_patches", *map(str, expected_dims))
raise ValueError(
"The expected shape of pixel values in each batch element "
f"is {expected_expr}. You supplied {tuple(d.shape)}.")
for d in data:
_validate_shape(d)
return data