[sleep mode] clear pytorch cache after sleep (#15248)
Signed-off-by: <villard@us.ibm.com>
This commit is contained in:
@@ -8,6 +8,7 @@
|
|||||||
# not sure why, they are created from a different context.
|
# not sure why, they are created from a different context.
|
||||||
# the only successful approach is to call cuda driver API in C.
|
# the only successful approach is to call cuda driver API in C.
|
||||||
import dataclasses
|
import dataclasses
|
||||||
|
import gc
|
||||||
import os
|
import os
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Any, Callable, Dict, Optional, Tuple, Union
|
from typing import Any, Callable, Dict, Optional, Tuple, Union
|
||||||
@@ -175,7 +176,7 @@ class CuMemAllocator:
|
|||||||
str]] = None) -> None:
|
str]] = None) -> None:
|
||||||
"""
|
"""
|
||||||
Put the allocator in sleep mode.
|
Put the allocator in sleep mode.
|
||||||
All data in the memory allocation with the specified tag will be
|
All data in the memory allocation with the specified tag will be
|
||||||
offloaded to CPU memory, and others will be discarded.
|
offloaded to CPU memory, and others will be discarded.
|
||||||
|
|
||||||
:param offload_tags: The tags of the memory allocation that will be
|
:param offload_tags: The tags of the memory allocation that will be
|
||||||
@@ -204,10 +205,13 @@ class CuMemAllocator:
|
|||||||
data.cpu_backup_tensor = cpu_backup_tensor
|
data.cpu_backup_tensor = cpu_backup_tensor
|
||||||
unmap_and_release(handle)
|
unmap_and_release(handle)
|
||||||
|
|
||||||
|
gc.collect()
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
def wake_up(self):
|
def wake_up(self):
|
||||||
"""
|
"""
|
||||||
Wake up the allocator from sleep mode.
|
Wake up the allocator from sleep mode.
|
||||||
All data that is previously offloaded will be loaded back to GPU
|
All data that is previously offloaded will be loaded back to GPU
|
||||||
memory, and the rest of the data will have empty memory."""
|
memory, and the rest of the data will have empty memory."""
|
||||||
for ptr, data in self.pointer_to_data.items():
|
for ptr, data in self.pointer_to_data.items():
|
||||||
handle = data.handle
|
handle = data.handle
|
||||||
@@ -225,7 +229,7 @@ class CuMemAllocator:
|
|||||||
def use_memory_pool(self, tag: Optional[str] = None):
|
def use_memory_pool(self, tag: Optional[str] = None):
|
||||||
"""
|
"""
|
||||||
A context manager to use the memory pool.
|
A context manager to use the memory pool.
|
||||||
All memory allocation created inside the context will be allocated
|
All memory allocation created inside the context will be allocated
|
||||||
in the memory pool, and has the specified tag.
|
in the memory pool, and has the specified tag.
|
||||||
|
|
||||||
:param tag: The tag of the memory allocation. If None, the default tag
|
:param tag: The tag of the memory allocation. If None, the default tag
|
||||||
|
|||||||
Reference in New Issue
Block a user