diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 750893272..12356b872 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -238,12 +238,29 @@ def nested_tensors_equal(a: NestedTensors, b: NestedTensors) -> bool: return isinstance(a, torch.Tensor) and torch.equal(b, a) if isinstance(a, list): - return isinstance(b, list) and all( - nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b) + return ( + isinstance(b, list) + and len(a) == len(b) + and all(nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b)) ) if isinstance(b, list): - return isinstance(a, list) and all( - nested_tensors_equal(b_, a_) for b_, a_ in zip(b, a) + return ( + isinstance(a, list) + and len(b) == len(a) + and all(nested_tensors_equal(b_, a_) for b_, a_ in zip(b, a)) + ) + + if isinstance(a, tuple): + return ( + isinstance(b, tuple) + and len(a) == len(b) + and all(nested_tensors_equal(a_, b_) for a_, b_ in zip(a, b)) + ) + if isinstance(b, tuple): + return ( + isinstance(a, tuple) + and len(b) == len(a) + and all(nested_tensors_equal(b_, a_) for b_, a_ in zip(b, a)) ) # Both a and b are scalars