[Multimodal] Fix nested_tensors_equal: add length check for lists and tuple support (#38388)

Signed-off-by: khairulkabir1661 <khairulkabir1661@users.noreply.github.com>
Co-authored-by: khairulkabir1661 <khairulkabir1661@users.noreply.github.com>
This commit is contained in:
Khairul Kabir
2026-04-08 21:40:37 -07:00
committed by GitHub
parent 2e98406048
commit 490f17d0c7

View File

@@ -238,12 +238,29 @@ def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
return isinstance(a, torch.Tensor) and torch.equal(b, a)
if isinstance(a, list):
return isinstance(b, list) and all(
nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b)
return (
isinstance(b, list)
and len(a) == len(b)
and all(nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b))
)
if isinstance(b, list):
return isinstance(a, list) and all(
nested_tensors_equal(b_, a_) for b_, a_ in zip(b, a)
return (
isinstance(a, list)
and len(b) == len(a)
and all(nested_tensors_equal(b_, a_) for b_, a_ in zip(b, a))
)
if isinstance(a, tuple):
return (
isinstance(b, tuple)
and len(a) == len(b)
and all(nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b))
)
if isinstance(b, tuple):
return (
isinstance(a, tuple)
and len(b) == len(a)
and all(nested_tensors_equal(b_, a_) for b_, a_ in zip(b, a))
)
# Both a and b are scalars