[Hardware][TPU] Refactor TPU backend (#5831)
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from typing import List, Set, Tuple
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -26,30 +26,46 @@ class TPUExecutor(ExecutorBase):
|
||||
self.model_config.dtype = torch.bfloat16
|
||||
|
||||
# Instantiate the worker and load the model to the device.
|
||||
self._init_worker()
|
||||
|
||||
def _init_worker(self):
|
||||
from vllm.worker.tpu_worker import TPUWorker
|
||||
|
||||
assert self.parallel_config.world_size == 1, (
|
||||
"TPUExecutor currently only supports a single TPU chip.")
|
||||
distributed_init_method = get_distributed_init_method(
|
||||
get_ip(), get_open_port())
|
||||
self.driver_worker = TPUWorker(
|
||||
self.model_config,
|
||||
self.parallel_config,
|
||||
self.scheduler_config,
|
||||
self.device_config,
|
||||
self.cache_config,
|
||||
self.load_config,
|
||||
self.vision_language_config,
|
||||
local_rank=0,
|
||||
rank=0,
|
||||
distributed_init_method=distributed_init_method,
|
||||
)
|
||||
self.driver_worker = self._create_worker()
|
||||
self.driver_worker.init_device()
|
||||
self.driver_worker.load_model()
|
||||
|
||||
def _get_worker_kwargs(
|
||||
self,
|
||||
local_rank: int = 0,
|
||||
rank: int = 0,
|
||||
distributed_init_method: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Return worker init args for a given rank."""
|
||||
if distributed_init_method is None:
|
||||
distributed_init_method = get_distributed_init_method(
|
||||
get_ip(), get_open_port())
|
||||
return dict(
|
||||
model_config=self.model_config,
|
||||
parallel_config=self.parallel_config,
|
||||
scheduler_config=self.scheduler_config,
|
||||
device_config=self.device_config,
|
||||
cache_config=self.cache_config,
|
||||
load_config=self.load_config,
|
||||
local_rank=local_rank,
|
||||
rank=rank,
|
||||
distributed_init_method=distributed_init_method,
|
||||
vision_language_config=self.vision_language_config,
|
||||
is_driver_worker=rank == 0,
|
||||
)
|
||||
|
||||
def _create_worker(
|
||||
self,
|
||||
local_rank: int = 0,
|
||||
rank: int = 0,
|
||||
distributed_init_method: Optional[str] = None,
|
||||
):
|
||||
from vllm.worker.tpu_worker import TPUWorker
|
||||
|
||||
worker = TPUWorker(**self._get_worker_kwargs(local_rank, rank,
|
||||
distributed_init_method))
|
||||
return worker
|
||||
|
||||
def initialize_cache(
|
||||
self,
|
||||
num_gpu_blocks: int,
|
||||
|
||||
Reference in New Issue
Block a user