[Bugfix] fix IntermediateTensors equal method (#23027)
Signed-off-by: Andy Xie <andy.xning@gmail.com>
This commit is contained in:
@@ -1163,7 +1163,13 @@ class IntermediateTensors:
|
||||
return len(self.tensors)
|
||||
|
||||
def __eq__(self, other: object):
|
||||
return isinstance(other, self.__class__) and self
|
||||
if not isinstance(other, self.__class__):
|
||||
return False
|
||||
if self.tensors.keys() != other.tensors.keys():
|
||||
return False
|
||||
return all(
|
||||
torch.equal(self.tensors[k], other.tensors[k])
|
||||
for k in self.tensors)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"IntermediateTensors(tensors={self.tensors})"
|
||||
|
||||
Reference in New Issue
Block a user