[V1] Refactor get_executor_cls (#11754)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Tuple
|
||||
from typing import Tuple, Type
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
@@ -8,6 +8,23 @@ from vllm.v1.outputs import ModelRunnerOutput
|
||||
class Executor(ABC):
|
||||
"""Abstract class for executors."""
|
||||
|
||||
@staticmethod
|
||||
def get_class(vllm_config: VllmConfig) -> Type["Executor"]:
|
||||
executor_class: Type[Executor]
|
||||
distributed_executor_backend = (
|
||||
vllm_config.parallel_config.distributed_executor_backend)
|
||||
if distributed_executor_backend == "ray":
|
||||
from vllm.v1.executor.ray_executor import RayExecutor
|
||||
executor_class = RayExecutor
|
||||
elif distributed_executor_backend == "mp":
|
||||
from vllm.v1.executor.multiproc_executor import MultiprocExecutor
|
||||
executor_class = MultiprocExecutor
|
||||
else:
|
||||
assert (distributed_executor_backend is None)
|
||||
from vllm.v1.executor.uniproc_executor import UniprocExecutor
|
||||
executor_class = UniprocExecutor
|
||||
return executor_class
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, vllm_config: VllmConfig) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
Reference in New Issue
Block a user