[BugFix][Multi Modal] Fix TensorSchema shape mismatch in Molmo (#24559)

Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
This commit is contained in:
Wenlong Wang
2025-09-10 06:14:27 -07:00
committed by GitHub
parent f36355abfd
commit 4c04eef706

View File

@@ -76,20 +76,22 @@ class MolmoImageInputs(TensorSchema):
"""
Dimensions:
- bn: Batch size * number of images
- nc: Number of crops
- nc: Number of crops (dynamic)
- np: Number of patches
- tp: Token sequence positions
- pd: Patch dimension
"""
images: Annotated[Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bn", "nc", "np", "pd")]
TensorShape("bn", "nc", "np", "pd", dynamic_dims={"nc"})]
# The number of crops may vary per batch and per image, so `images` may be passed as a list of tensors.
image_masks: Annotated[Optional[Union[torch.Tensor, list[torch.Tensor]]],
TensorShape("bn", "nc", "np")]
TensorShape("bn", "nc", "np", dynamic_dims={"nc"})]
feat_is_patch: Annotated[Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bn", "nc", "np")]
feat_is_patch: Annotated[
Union[torch.Tensor, list[torch.Tensor]],
TensorShape("bn", "nc", "tp", dynamic_dims={"nc"})]
# A boolean mask indicating which image features correspond to patch tokens.
num_crops: Annotated[torch.Tensor, TensorShape("bn")]