[Performance] Split FlashAttn attention and cache update (#25954)
Signed-off-by: ElizaWszola <ewszola@redhat.com> Signed-off-by: mgoin <mgoin64@gmail.com> Signed-off-by: Matthew Bonanni <mbonanni@redhat.com> Signed-off-by: Luka Govedič <luka.govedic@gmail.com> Signed-off-by: Lucas Wilkinson <lwilkins@redhat.com> Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Signed-off-by: Luka Govedič <lgovedic@redhat.com> Co-authored-by: mgoin <mgoin64@gmail.com> Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com> Co-authored-by: Matthew Bonanni <mbonanni@redhat.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: Luka Govedič <luka.govedic@gmail.com> Co-authored-by: Lucas Wilkinson <lwilkins@redhat.com> Co-authored-by: Luka Govedič <lgovedic@redhat.com>
This commit is contained in:
@@ -613,8 +613,9 @@ def weak_ref_tensor(tensor: Any) -> Any:
|
||||
Create a weak reference to a tensor.
|
||||
The new tensor will share the same data as the original tensor,
|
||||
but will not keep the original tensor alive.
|
||||
This ignores 0-size tensors as those don't allocate any memory.
|
||||
"""
|
||||
if isinstance(tensor, torch.Tensor):
|
||||
if isinstance(tensor, torch.Tensor) and tensor.numel() > 0:
|
||||
return torch.ops._C.weak_ref_tensor(tensor)
|
||||
else:
|
||||
return tensor
|
||||
|
||||
Reference in New Issue
Block a user