[Hardware] [Intel GPU] refactor xpu worker/executor (#7686)

This commit is contained in:
Kunshang Ji
2024-08-21 00:54:10 +08:00
committed by GitHub
parent aae6927be0
commit c42590f97a
3 changed files with 26 additions and 28 deletions

View File

@@ -1,16 +1,16 @@
from typing import List, Optional
from typing import List, Optional, Tuple, Union
import torch
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ParallelConfig, PromptAdapterConfig,
SchedulerConfig, SpeculativeConfig)
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.executor.executor_base import ExecutorAsyncBase
from vllm.executor.gpu_executor import GPUExecutor
from vllm.logger import init_logger
from vllm.sequence import ExecuteModelRequest, SamplerOutput
from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput
from vllm.utils import make_async
from vllm.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__)
@@ -30,6 +30,7 @@ class XPUExecutor(GPUExecutor):
lora_config: Optional[LoRAConfig],
prompt_adapter_config: Optional[PromptAdapterConfig],
speculative_config: Optional[SpeculativeConfig],
observability_config: Optional[ObservabilityConfig],
) -> None:
assert device_config.device_type == "xpu"
assert (not speculative_config
@@ -46,32 +47,23 @@ class XPUExecutor(GPUExecutor):
self.device_config = device_config
self.prompt_adapter_config = prompt_adapter_config
self.speculative_config = None
self.observability_config = observability_config
# Instantiate the worker and load the model to GPU.
self._init_executor()
def _create_worker(self,
local_rank: int = 0,
rank: int = 0,
distributed_init_method: Optional[str] = None):
if self.speculative_config is None:
worker_module_name = "vllm.worker.xpu_worker"
worker_class_name = "XPUWorker"
else:
def _get_worker_module_and_class(self) -> Tuple[str, str]:
if self.speculative_config is not None:
raise NotImplementedError(
"XPU does not support speculative decoding")
wrapper = WorkerWrapperBase(
worker_module_name=worker_module_name,
worker_class_name=worker_class_name,
)
wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
distributed_init_method))
return wrapper.worker
else:
worker_module_name = "vllm.worker.xpu_worker"
worker_class_name = "XPUWorker"
return (worker_module_name, worker_class_name)
def execute_model(
self,
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
self, execute_model_req: ExecuteModelRequest
) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
output = self.driver_worker.execute_model(execute_model_req)
return output