[Core][AMD] Migrate fully transparent sleep mode to ROCm platform (#12695)

Signed-off-by: Hollow Man <hollowman@opensuse.org>
Signed-off-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
Co-authored-by: kliuae <kuanfu.liu@embeddedllm.com>
Commit: 4ca5cd5740 (parent 10f01d5a3a)
Author: Hollow Man
Committed via GitHub: 2025-11-13 01:24:12 +02:00
11 changed files with 582 additions and 31 deletions


@@ -264,7 +264,8 @@ class ModelConfig:
     merged with the default config from the model. If used with
     `--generation-config vllm`, only the override parameters are used."""
     enable_sleep_mode: bool = False
-    """Enable sleep mode for the engine (only cuda platform is supported)."""
+    """Enable sleep mode for the engine (only cuda and
+    hip platforms are supported)."""
     model_impl: str | ModelImpl = "auto"
     """Which implementation of the model to use:\n
     - "auto" will try to use the vLLM implementation, if it exists, and fall


@@ -63,7 +63,7 @@ try:
     libcudart = CudaRTLibrary()
     cumem_available = True
 except ModuleNotFoundError:
-    # rocm platform does not support cumem allocator
+    # only cuda and rocm platforms support cumem allocator
     init_module = None
     python_create_and_map = None
     python_unmap_and_release = None
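
When the guarded import succeeds (now also via HIP), the worker routes weight and KV-cache allocations through `CuMemAllocator` so they can be released and restored wholesale. A minimal sketch of that flow, using the allocator API this guard protects (tags follow vLLM's weights/KV convention):

```python
# Sketch: how sleep mode drives the cumem allocator once it is available.
from vllm.device_allocator.cumem import CuMemAllocator

allocator = CuMemAllocator.get_instance()

# Allocations made inside the pool are tagged, so sleep() can decide
# what to offload to CPU and what to simply discard.
with allocator.use_memory_pool(tag="weights"):
    ...  # model weights would be allocated here

allocator.sleep(offload_tags=("weights",))  # offload weights, drop the rest
allocator.wake_up()                         # remap and restore
```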


@@ -14,6 +14,7 @@ import torch # noqa
 import vllm.envs as envs
 from vllm.logger import init_logger
+from vllm.platforms import current_platform

 logger = init_logger(__name__)
@@ -105,6 +106,20 @@ class CudaRTLibrary:
         ),
     ]

+    # https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Runtime_API_functions_supported_by_HIP.html # noqa
+    cuda_to_hip_mapping = {
+        "cudaSetDevice": "hipSetDevice",
+        "cudaDeviceSynchronize": "hipDeviceSynchronize",
+        "cudaDeviceReset": "hipDeviceReset",
+        "cudaGetErrorString": "hipGetErrorString",
+        "cudaMalloc": "hipMalloc",
+        "cudaFree": "hipFree",
+        "cudaMemset": "hipMemset",
+        "cudaMemcpy": "hipMemcpy",
+        "cudaIpcGetMemHandle": "hipIpcGetMemHandle",
+        "cudaIpcOpenMemHandle": "hipIpcOpenMemHandle",
+    }
+
     # class attribute to store the mapping from the path to the library
     # to avoid loading the same library multiple times
     path_to_library_cache: dict[str, Any] = {}
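
Since the table is a pure one-to-one rename, only symbol lookup changes; the rest of the ctypes plumbing is untouched. The same trick in standalone form (the soname assumes a ROCm install):

```python
# Sketch: resolve a HIP entry point under its CUDA name via ctypes.
import ctypes

cuda_to_hip = {"cudaGetErrorString": "hipGetErrorString"}

lib = ctypes.CDLL("libamdhip64.so")  # ROCm's HIP runtime
f = getattr(lib, cuda_to_hip["cudaGetErrorString"])
f.restype = ctypes.c_char_p
f.argtypes = [ctypes.c_int]

# Error code 0 (success) described through the CUDA-named lookup.
print(f(0).decode())
```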
@@ -117,7 +132,13 @@ class CudaRTLibrary:
         if so_file is None:
             so_file = find_loaded_library("libcudart")
         if so_file is None:
-            so_file = envs.VLLM_CUDART_SO_PATH  # fallback to env var
+            # libcudart is not loaded in the current process, try hip
+            so_file = find_loaded_library("libamdhip64")
+            # should be safe to assume now that we are using ROCm
+            # as the following assertion should error out if the
+            # libhiprtc library is also not loaded
+        if so_file is None:
+            so_file = envs.VLLM_CUDART_SO_PATH  # fallback to env var
         assert so_file is not None, (
             "libcudart is not loaded in the current process, "
             "try setting VLLM_CUDART_SO_PATH"
@@ -130,7 +151,12 @@ class CudaRTLibrary:
         if so_file not in CudaRTLibrary.path_to_dict_mapping:
             _funcs = {}
             for func in CudaRTLibrary.exported_functions:
-                f = getattr(self.lib, func.name)
+                f = getattr(
+                    self.lib,
+                    CudaRTLibrary.cuda_to_hip_mapping[func.name]
+                    if current_platform.is_rocm()
+                    else func.name,
+                )
                 f.restype = func.restype
                 f.argtypes = func.argtypes
                 _funcs[func.name] = f
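
With the rename applied once at load time, every call site keeps the CUDA spelling and transparently dispatches to HIP on ROCm. For example (import path assumes the wrapper's historical location; device index and sizes are illustrative):

```python
# The wrapper's CUDA-named surface works identically on both platforms.
from vllm.distributed.device_communicators.cuda_wrapper import CudaRTLibrary

libcudart = CudaRTLibrary()
libcudart.cudaSetDevice(0)           # hipSetDevice on ROCm
ptr = libcudart.cudaMalloc(1 << 20)  # hipMalloc on ROCm
libcudart.cudaMemset(ptr, 0, 1 << 20)
libcudart.cudaFree(ptr)
```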


@@ -18,6 +18,7 @@ if TYPE_CHECKING:
     VLLM_RINGBUFFER_WARNING_INTERVAL: int = 60
     VLLM_NCCL_SO_PATH: str | None = None
     LD_LIBRARY_PATH: str | None = None
+    VLLM_ROCM_SLEEP_MEM_CHUNK_SIZE: int = 256
     VLLM_V1_USE_PREFILL_DECODE_ATTENTION: bool = False
     VLLM_FLASH_ATTN_VERSION: int | None = None
     LOCAL_RANK: int = 0
@@ -520,6 +521,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # when `VLLM_NCCL_SO_PATH` is not set, vllm will try to find the nccl
     # library file in the locations specified by `LD_LIBRARY_PATH`
     "LD_LIBRARY_PATH": lambda: os.environ.get("LD_LIBRARY_PATH", None),
+    # flag to control the chunk size (in MB) for sleeping memory allocations under ROCm
+    "VLLM_ROCM_SLEEP_MEM_CHUNK_SIZE": lambda: int(
+        os.environ.get("VLLM_ROCM_SLEEP_MEM_CHUNK_SIZE", "256")
+    ),
     # Use separate prefill and decode kernels for V1 attention instead of
     # the unified triton kernel.
     "VLLM_V1_USE_PREFILL_DECODE_ATTENTION": lambda: (


@@ -171,7 +171,11 @@ class Platform:
         return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)

     def is_sleep_mode_available(self) -> bool:
-        return self._enum == PlatformEnum.CUDA
+        # TODO: Actually only mi3xx has the sleep mode support now
+        # for ROCm, but currently we don't have a way to detect the
+        # exact GPU model statelessly here. So we return True for
+        # all ROCm platforms for now.
+        return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)

     @classmethod
     def device_id_to_physical_device_id(cls, device_id: int):
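
Downstream code can now gate on the platform object instead of hard-coding CUDA; a small sketch:

```python
# Sketch: guard sleep mode behind the platform capability check.
from vllm import LLM
from vllm.platforms import current_platform

llm = LLM(
    model="facebook/opt-125m",  # illustrative
    enable_sleep_mode=current_platform.is_sleep_mode_available(),
)
```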