[VLM] Cleanup validation and update docs (#6149)
This commit is contained in:
@@ -263,7 +263,8 @@ class Phi3VImagePixelInputs(TypedDict):
|
||||
"""
|
||||
Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
|
||||
|
||||
Note that `num_patches` may be different for each batch.
|
||||
Note that `num_patches` may be different for each batch, in which case
|
||||
the data is passed as a list instead of a batched tensor.
|
||||
"""
|
||||
|
||||
image_sizes: torch.Tensor
|
||||
@@ -466,8 +467,8 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
|
||||
def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
|
||||
if list(data.shape[1:]) != [2]:
|
||||
raise ValueError(
|
||||
f"The expected image sizes shape is batch dimension plus "
|
||||
f"{[2]}. You supplied {data.shape}.")
|
||||
f"The expected shape of image sizes is batch dimension plus "
|
||||
f"{[2]}. You supplied {tuple(data.shape)}.")
|
||||
|
||||
return data
|
||||
|
||||
@@ -475,19 +476,20 @@ class Phi3VForCausalLM(nn.Module, SupportsVision):
|
||||
self, data: Union[torch.Tensor, List[torch.Tensor]]
|
||||
) -> Union[torch.Tensor, List[torch.Tensor]]:
|
||||
|
||||
def _validate_shape(data: torch.Tensor):
|
||||
if list(data.shape)[2:] != [
|
||||
3, CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size,
|
||||
CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size
|
||||
]:
|
||||
raise ValueError(
|
||||
"The expected pixel value tensor shape is batch dimension "
|
||||
"plus patch number, channel, height and width.")
|
||||
h = w = CLIP_VIT_LARGE_PATCH14_336_CONFIG.image_size
|
||||
expected_dims = (3, h, w)
|
||||
|
||||
if isinstance(data, torch.Tensor):
|
||||
_validate_shape(data)
|
||||
else:
|
||||
[_validate_shape(d) for d in data]
|
||||
def _validate_shape(d: torch.Tensor):
|
||||
actual_dims = tuple(d.shape[1:])
|
||||
|
||||
if actual_dims != expected_dims:
|
||||
expected_expr = ("num_patches", *map(str, expected_dims))
|
||||
raise ValueError(
|
||||
"The expected shape of pixel values in each batch element "
|
||||
f"is {expected_expr}. You supplied {tuple(d.shape)}.")
|
||||
|
||||
for d in data:
|
||||
_validate_shape(d)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
Reference in New Issue
Block a user