[Hardware] [Intel GPU] refactor xpu worker/executor (#7686)
This commit is contained in:
@@ -1,16 +1,16 @@
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
|
||||
ModelConfig, ParallelConfig, PromptAdapterConfig,
|
||||
SchedulerConfig, SpeculativeConfig)
|
||||
ModelConfig, ObservabilityConfig, ParallelConfig,
|
||||
PromptAdapterConfig, SchedulerConfig,
|
||||
SpeculativeConfig)
|
||||
from vllm.executor.executor_base import ExecutorAsyncBase
|
||||
from vllm.executor.gpu_executor import GPUExecutor
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sequence import ExecuteModelRequest, SamplerOutput
|
||||
from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput
|
||||
from vllm.utils import make_async
|
||||
from vllm.worker.worker_base import WorkerWrapperBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -30,6 +30,7 @@ class XPUExecutor(GPUExecutor):
|
||||
lora_config: Optional[LoRAConfig],
|
||||
prompt_adapter_config: Optional[PromptAdapterConfig],
|
||||
speculative_config: Optional[SpeculativeConfig],
|
||||
observability_config: Optional[ObservabilityConfig],
|
||||
) -> None:
|
||||
assert device_config.device_type == "xpu"
|
||||
assert (not speculative_config
|
||||
@@ -46,32 +47,23 @@ class XPUExecutor(GPUExecutor):
|
||||
self.device_config = device_config
|
||||
self.prompt_adapter_config = prompt_adapter_config
|
||||
self.speculative_config = None
|
||||
self.observability_config = observability_config
|
||||
|
||||
# Instantiate the worker and load the model to GPU.
|
||||
self._init_executor()
|
||||
|
||||
def _create_worker(self,
|
||||
local_rank: int = 0,
|
||||
rank: int = 0,
|
||||
distributed_init_method: Optional[str] = None):
|
||||
if self.speculative_config is None:
|
||||
worker_module_name = "vllm.worker.xpu_worker"
|
||||
worker_class_name = "XPUWorker"
|
||||
else:
|
||||
def _get_worker_module_and_class(self) -> Tuple[str, str]:
|
||||
if self.speculative_config is not None:
|
||||
raise NotImplementedError(
|
||||
"XPU does not support speculative decoding")
|
||||
|
||||
wrapper = WorkerWrapperBase(
|
||||
worker_module_name=worker_module_name,
|
||||
worker_class_name=worker_class_name,
|
||||
)
|
||||
wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
|
||||
distributed_init_method))
|
||||
return wrapper.worker
|
||||
else:
|
||||
worker_module_name = "vllm.worker.xpu_worker"
|
||||
worker_class_name = "XPUWorker"
|
||||
return (worker_module_name, worker_class_name)
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
|
||||
self, execute_model_req: ExecuteModelRequest
|
||||
) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]:
|
||||
output = self.driver_worker.execute_model(execute_model_req)
|
||||
return output
|
||||
|
||||
|
||||
Reference in New Issue
Block a user