[platforms] refactor cpu code (#10402)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -2,9 +2,6 @@ import os
|
|||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Any, Awaitable, List, Optional, Set, Tuple, Union
|
from typing import Any, Awaitable, List, Optional, Set, Tuple, Union
|
||||||
|
|
||||||
import vllm.envs as envs
|
|
||||||
from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
|
|
||||||
SchedulerConfig)
|
|
||||||
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
|
from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
|
||||||
from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
|
from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
|
||||||
ResultHandler, WorkerMonitor)
|
ResultHandler, WorkerMonitor)
|
||||||
@@ -13,7 +10,7 @@ from vllm.lora.request import LoRARequest
|
|||||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||||
from vllm.sequence import ExecuteModelRequest
|
from vllm.sequence import ExecuteModelRequest
|
||||||
from vllm.utils import (GiB_bytes, get_distributed_init_method, get_open_port,
|
from vllm.utils import (get_distributed_init_method, get_open_port,
|
||||||
get_vllm_instance_id, make_async)
|
get_vllm_instance_id, make_async)
|
||||||
from vllm.worker.worker_base import WorkerWrapperBase
|
from vllm.worker.worker_base import WorkerWrapperBase
|
||||||
|
|
||||||
@@ -57,13 +54,6 @@ class CPUExecutor(ExecutorBase):
|
|||||||
os.environ["LOCAL_WORLD_SIZE"] = str(
|
os.environ["LOCAL_WORLD_SIZE"] = str(
|
||||||
self.parallel_config.tensor_parallel_size)
|
self.parallel_config.tensor_parallel_size)
|
||||||
|
|
||||||
self.model_config = _verify_and_get_model_config(self.model_config)
|
|
||||||
self.cache_config = _verify_and_get_cache_config(self.cache_config)
|
|
||||||
self.scheduler_config = _verify_and_get_scheduler_config(
|
|
||||||
self.scheduler_config)
|
|
||||||
self.parallel_config = _verify_and_get_parallel_config(
|
|
||||||
self.parallel_config)
|
|
||||||
|
|
||||||
# Multiprocessing-based executor does not support multi-node setting.
|
# Multiprocessing-based executor does not support multi-node setting.
|
||||||
# Since it only works for single node, we can use the loopback address
|
# Since it only works for single node, we can use the loopback address
|
||||||
# 127.0.0.1 for communication.
|
# 127.0.0.1 for communication.
|
||||||
@@ -313,62 +303,6 @@ class CPUExecutorAsync(CPUExecutor, ExecutorAsyncBase):
|
|||||||
self.check_health()
|
self.check_health()
|
||||||
|
|
||||||
|
|
||||||
def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig:
|
|
||||||
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
|
|
||||||
# If the feature combo become valid
|
|
||||||
if not config.enforce_eager:
|
|
||||||
logger.warning(
|
|
||||||
"CUDA graph is not supported on CPU, fallback to the eager "
|
|
||||||
"mode.")
|
|
||||||
config.enforce_eager = True
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
def _verify_and_get_scheduler_config(
|
|
||||||
config: SchedulerConfig) -> SchedulerConfig:
|
|
||||||
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
|
|
||||||
# If the feature combo become valid
|
|
||||||
if config.chunked_prefill_enabled:
|
|
||||||
logger.warning("Chunked prefill is not supported on CPU, disable it.")
|
|
||||||
config.chunked_prefill_enabled = False
|
|
||||||
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig:
|
|
||||||
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
|
|
||||||
# If the feature combo become valid
|
|
||||||
if config.enable_prefix_caching:
|
|
||||||
logger.warning("Prefix caching is not supported on CPU, disable it.")
|
|
||||||
config.enable_prefix_caching = False
|
|
||||||
|
|
||||||
kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE
|
|
||||||
|
|
||||||
if kv_cache_space >= 0:
|
|
||||||
if kv_cache_space == 0:
|
|
||||||
config.cpu_kvcache_space_bytes = 4 * GiB_bytes # type: ignore
|
|
||||||
logger.warning("Environment variable VLLM_CPU_KVCACHE_SPACE (GB) "
|
|
||||||
"for CPU backend is not set, using 4 by default.")
|
|
||||||
else:
|
|
||||||
config.cpu_kvcache_space_bytes = kv_cache_space * GiB_bytes # type: ignore
|
|
||||||
else:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Invalid environment variable VLLM_CPU_KVCACHE_SPACE"
|
|
||||||
f" {kv_cache_space}, expect a positive integer value.")
|
|
||||||
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
def _verify_and_get_parallel_config(config: ParallelConfig) -> ParallelConfig:
|
|
||||||
if (config.distributed_executor_backend is not None
|
|
||||||
and config.distributed_executor_backend != "mp"):
|
|
||||||
logger.warning(
|
|
||||||
"%s is not supported on CPU, fallback to mp distributed executor "
|
|
||||||
"backend.", config.distributed_executor_backend)
|
|
||||||
config.distributed_executor_backend = "mp"
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
def _driver_method_invoker(driver, method: str, *args, **kwargs):
|
def _driver_method_invoker(driver, method: str, *args, **kwargs):
|
||||||
return getattr(driver, method)(*args, **kwargs)
|
return getattr(driver, method)(*args, **kwargs)
|
||||||
|
|
||||||
|
|||||||
@@ -1,8 +1,19 @@
|
|||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import psutil
|
import psutil
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
from .interface import Platform, PlatformEnum
|
from .interface import Platform, PlatformEnum
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
else:
|
||||||
|
VllmConfig = None
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class CpuPlatform(Platform):
|
class CpuPlatform(Platform):
|
||||||
_enum = PlatformEnum.CPU
|
_enum = PlatformEnum.CPU
|
||||||
@@ -18,3 +29,52 @@ class CpuPlatform(Platform):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def inference_mode(cls):
|
def inference_mode(cls):
|
||||||
return torch.no_grad()
|
return torch.no_grad()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
|
||||||
|
import vllm.envs as envs
|
||||||
|
from vllm.utils import GiB_bytes
|
||||||
|
model_config = vllm_config.model_config
|
||||||
|
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
|
||||||
|
# If the feature combo become valid
|
||||||
|
if not model_config.enforce_eager:
|
||||||
|
logger.warning(
|
||||||
|
"CUDA graph is not supported on CPU, fallback to the eager "
|
||||||
|
"mode.")
|
||||||
|
model_config.enforce_eager = True
|
||||||
|
|
||||||
|
cache_config = vllm_config.cache_config
|
||||||
|
|
||||||
|
if cache_config.enable_prefix_caching:
|
||||||
|
logger.warning(
|
||||||
|
"Prefix caching is not supported on CPU, disable it.")
|
||||||
|
cache_config.enable_prefix_caching = False
|
||||||
|
|
||||||
|
kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE
|
||||||
|
|
||||||
|
if kv_cache_space >= 0:
|
||||||
|
if kv_cache_space == 0:
|
||||||
|
cache_config.cpu_kvcache_space_bytes = 4 * GiB_bytes # type: ignore
|
||||||
|
logger.warning(
|
||||||
|
"Environment variable VLLM_CPU_KVCACHE_SPACE (GB) "
|
||||||
|
"for CPU backend is not set, using 4 by default.")
|
||||||
|
else:
|
||||||
|
cache_config.cpu_kvcache_space_bytes = kv_cache_space * GiB_bytes # type: ignore # noqa
|
||||||
|
else:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Invalid environment variable VLLM_CPU_KVCACHE_SPACE"
|
||||||
|
f" {kv_cache_space}, expect a positive integer value.")
|
||||||
|
|
||||||
|
scheduler_config = vllm_config.scheduler_config
|
||||||
|
if scheduler_config.chunked_prefill_enabled:
|
||||||
|
logger.warning(
|
||||||
|
"Chunked prefill is not supported on CPU, disable it.")
|
||||||
|
scheduler_config.chunked_prefill_enabled = False
|
||||||
|
|
||||||
|
parallel_config = vllm_config.parallel_config
|
||||||
|
if (parallel_config.distributed_executor_backend is not None
|
||||||
|
and parallel_config.distributed_executor_backend != "mp"):
|
||||||
|
logger.warning(("%s is not supported on CPU, fallback to mp "
|
||||||
|
"distributed executor backend."),
|
||||||
|
parallel_config.distributed_executor_backend)
|
||||||
|
parallel_config.distributed_executor_backend = "mp"
|
||||||
|
|||||||
Reference in New Issue
Block a user