[Multimodal] Fix nested_tensors_equal: add length check for lists and tuple support (#38388)
Signed-off-by: khairulkabir1661 <khairulkabir1661@users.noreply.github.com> Co-authored-by: khairulkabir1661 <khairulkabir1661@users.noreply.github.com>
This commit is contained in:
@@ -238,12 +238,29 @@ def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool:
|
||||
return isinstance(a, torch.Tensor) and torch.equal(b, a)
|
||||
|
||||
if isinstance(a, list):
|
||||
return isinstance(b, list) and all(
|
||||
nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b)
|
||||
return (
|
||||
isinstance(b, list)
|
||||
and len(a) == len(b)
|
||||
and all(nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b))
|
||||
)
|
||||
if isinstance(b, list):
|
||||
return isinstance(a, list) and all(
|
||||
nested_tensors_equal(b_, a_) for b_, a_ in zip(b, a)
|
||||
return (
|
||||
isinstance(a, list)
|
||||
and len(b) == len(a)
|
||||
and all(nested_tensors_equal(b_, a_) for b_, a_ in zip(b, a))
|
||||
)
|
||||
|
||||
if isinstance(a, tuple):
|
||||
return (
|
||||
isinstance(b, tuple)
|
||||
and len(a) == len(b)
|
||||
and all(nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b))
|
||||
)
|
||||
if isinstance(b, tuple):
|
||||
return (
|
||||
isinstance(a, tuple)
|
||||
and len(b) == len(a)
|
||||
and all(nested_tensors_equal(b_, a_) for b_, a_ in zip(b, a))
|
||||
)
|
||||
|
||||
# Both a and b are scalars
|
||||
|
||||
Reference in New Issue
Block a user