[Core] Allow specifying custom Executor (#6557)
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import enum
|
||||
import json
|
||||
from dataclasses import dataclass, field, fields
|
||||
from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
from transformers import PretrainedConfig
|
||||
@@ -18,7 +18,10 @@ from vllm.utils import (cuda_device_count_stateless, get_cpu_memory, is_cpu,
|
||||
if TYPE_CHECKING:
|
||||
from ray.util.placement_group import PlacementGroup
|
||||
|
||||
from vllm.executor.executor_base import ExecutorBase
|
||||
from vllm.model_executor.model_loader.loader import BaseModelLoader
|
||||
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
|
||||
BaseTokenizerGroup)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -527,11 +530,12 @@ class TokenizerPoolConfig:
|
||||
pool type.
|
||||
"""
|
||||
pool_size: int
|
||||
pool_type: str
|
||||
pool_type: Union[str, Type["BaseTokenizerGroup"]]
|
||||
extra_config: dict
|
||||
|
||||
def __post_init__(self):
|
||||
if self.pool_type not in ("ray", ):
|
||||
if self.pool_type not in ("ray", ) and not isinstance(
|
||||
self.pool_type, type):
|
||||
raise ValueError(f"Unknown pool type: {self.pool_type}")
|
||||
if not isinstance(self.extra_config, dict):
|
||||
raise ValueError("extra_config must be a dictionary.")
|
||||
@@ -661,7 +665,8 @@ class ParallelConfig:
|
||||
tokenizer_pool_config: Optional[TokenizerPoolConfig] = None,
|
||||
ray_workers_use_nsight: bool = False,
|
||||
placement_group: Optional["PlacementGroup"] = None,
|
||||
distributed_executor_backend: Optional[str] = None,
|
||||
distributed_executor_backend: Optional[Union[
|
||||
str, Type["ExecutorBase"]]] = None,
|
||||
) -> None:
|
||||
self.pipeline_parallel_size = pipeline_parallel_size
|
||||
self.tensor_parallel_size = tensor_parallel_size
|
||||
@@ -676,7 +681,7 @@ class ParallelConfig:
|
||||
if worker_use_ray:
|
||||
if self.distributed_executor_backend is None:
|
||||
self.distributed_executor_backend = "ray"
|
||||
elif self.distributed_executor_backend != "ray":
|
||||
elif not self.use_ray:
|
||||
raise ValueError(f"worker-use-ray can't be used with "
|
||||
f"distributed executor backend "
|
||||
f"'{self.distributed_executor_backend}'.")
|
||||
@@ -711,12 +716,25 @@ class ParallelConfig:
|
||||
self._verify_args()
|
||||
self.rank = 0
|
||||
|
||||
@property
|
||||
def use_ray(self) -> bool:
|
||||
return self.distributed_executor_backend == "ray" or (
|
||||
isinstance(self.distributed_executor_backend, type)
|
||||
and self.distributed_executor_backend.uses_ray)
|
||||
|
||||
def _verify_args(self) -> None:
|
||||
if self.distributed_executor_backend not in ("ray", "mp", None):
|
||||
# Lazy import to avoid circular import
|
||||
from vllm.executor.executor_base import ExecutorBase
|
||||
|
||||
if self.distributed_executor_backend not in (
|
||||
"ray", "mp", None) and not (isinstance(
|
||||
self.distributed_executor_backend, type) and issubclass(
|
||||
self.distributed_executor_backend, ExecutorBase)):
|
||||
raise ValueError(
|
||||
"Unrecognized distributed executor backend. Supported values "
|
||||
"are 'ray' or 'mp'.")
|
||||
if self.distributed_executor_backend == "ray":
|
||||
"Unrecognized distributed executor backend "
|
||||
f"{self.distributed_executor_backend}. Supported "
|
||||
"values are 'ray', 'mp' or custom ExecutorBase subclass.")
|
||||
if self.use_ray:
|
||||
from vllm.executor import ray_utils
|
||||
ray_utils.assert_ray_available()
|
||||
if is_hip():
|
||||
@@ -724,8 +742,7 @@ class ParallelConfig:
|
||||
logger.info(
|
||||
"Disabled the custom all-reduce kernel because it is not "
|
||||
"supported on AMD GPUs.")
|
||||
if self.ray_workers_use_nsight and (
|
||||
not self.distributed_executor_backend == "ray"):
|
||||
if self.ray_workers_use_nsight and not self.use_ray:
|
||||
raise ValueError("Unable to use nsight profiling unless workers "
|
||||
"run with Ray.")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user