[Misc] Add tensor schema test coverage for multimodal models (#21754)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn> Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
@@ -51,13 +51,14 @@ class DeepseekVL2ImagePixelInputs(TensorSchema):
|
||||
"""
|
||||
Dimensions:
|
||||
- bn: Batch size * number of images
|
||||
- p: Number of patches
|
||||
- c: Number of channels (3)
|
||||
- h: Height of each image
|
||||
- w: Width of each image
|
||||
"""
|
||||
type: Literal["pixel_values"]
|
||||
data: Annotated[Union[torch.Tensor, list[torch.Tensor]],
|
||||
TensorShape("bn", 3, "h", "w")]
|
||||
TensorShape("bn", "p", 3, "h", "w", dynamic_dims={"p"})]
|
||||
images_spatial_crop: Annotated[torch.Tensor, TensorShape("bn", 2)]
|
||||
|
||||
|
||||
|
||||
@@ -104,13 +104,16 @@ def smart_resize(
|
||||
class KeyeImagePixelInputs(TensorSchema):
|
||||
"""
|
||||
Dimensions:
|
||||
- b: Batch size
|
||||
- np: Number of patches
|
||||
- cps: Number of channels * patch_size * patch_size
|
||||
- c: Number of channels
|
||||
- ps: Patch size
|
||||
- ni: Number of images
|
||||
- g: Grid dimensions (3 for t, h, w)
|
||||
"""
|
||||
type: Literal["pixel_values"]
|
||||
pixel_values: Annotated[torch.Tensor, TensorShape("np", "cps")]
|
||||
pixel_values: Annotated[torch.Tensor,
|
||||
TensorShape("b", "np", 3, "ps", "ps")]
|
||||
image_grid_thw: Annotated[torch.Tensor, TensorShape("ni", 3)]
|
||||
|
||||
|
||||
@@ -134,14 +137,16 @@ KeyeImageInputs = Union[KeyeImagePixelInputs, KeyeImageEmbeddingInputs]
|
||||
class KeyeVideoPixelInputs(TensorSchema):
|
||||
"""
|
||||
Dimensions:
|
||||
- b: Batch size
|
||||
- np: Number of patches
|
||||
- ctps: Number of channels * temporal_patch_size * patch_size *
|
||||
patch_size
|
||||
- nv: Number of videos
|
||||
- c: Number of channels
|
||||
- ps: Patch size
|
||||
- ni: Number of images
|
||||
- g: Grid dimensions (3 for t, h, w)
|
||||
"""
|
||||
type: Literal["pixel_values_videos"]
|
||||
pixel_values_videos: Annotated[torch.Tensor, TensorShape("np", "ctps")]
|
||||
pixel_values_videos: Annotated[torch.Tensor,
|
||||
TensorShape("b", "np", 3, "ps", "ps")]
|
||||
video_grid_thw: Annotated[torch.Tensor, TensorShape("nv", 3)]
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user