[VLM][Core] Fix exceptions on ragged NestedTensors (#7974)

This commit is contained in:
Peter Salas
2024-08-28 20:24:31 -07:00
committed by GitHub
parent a7f65c2be9
commit 74d5543ec5
3 changed files with 21 additions and 11 deletions

View File

@@ -81,3 +81,15 @@ def test_multimodal_input_batch_multiple_batchable_lists():
result,
{"image": torch.stack([torch.stack([a, b]),
torch.stack([c, d])])})
def test_multimodal_input_batch_mixed_stacking_depths():
a = torch.rand([1, 2, 3])
b = torch.rand([1, 3, 3])
c = torch.rand([1, 4, 3])
result = MultiModalInputs.batch([{"image": [a, b]}, {"image": [c]}])
assert_multimodal_inputs_equal(result, {"image": [[a, b], c.unsqueeze(0)]})
result = MultiModalInputs.batch([{"image": [a]}, {"image": [b, c]}])
assert_multimodal_inputs_equal(result, {"image": [a.unsqueeze(0), [b, c]]})