[Bugfix] fix IntermediateTensors equal method (#23027)

Signed-off-by: Andy Xie <andy.xning@gmail.com>
This commit is contained in:
Ning Xie
2025-08-18 17:58:11 +08:00
committed by GitHub
parent 27e8d1ea3e
commit 5a30bd10d8
2 changed files with 45 additions and 3 deletions

View File

@@ -1163,7 +1163,13 @@ class IntermediateTensors:
return len(self.tensors)
def __eq__(self, other: object):
return isinstance(other, self.__class__) and self
if not isinstance(other, self.__class__):
return False
if self.tensors.keys() != other.tensors.keys():
return False
return all(
torch.equal(self.tensors[k], other.tensors[k])
for k in self.tensors)
def __repr__(self) -> str:
return f"IntermediateTensors(tensors={self.tensors})"