[TPU] Support single and multi-host TPUs on GKE (#7613)
This commit is contained in:
@@ -71,6 +71,19 @@ class RayTPUExecutor(TPUExecutor):
|
||||
worker_module_name = "vllm.worker.tpu_worker"
|
||||
worker_class_name = "TPUWorker"
|
||||
|
||||
# GKE does not fetch environment information from metadata server
|
||||
# and instead sets these from within the Ray process. Therefore we
|
||||
# need to override the Ray environment variables manually.
|
||||
override_env = {}
|
||||
if "TPU_CHIPS_PER_HOST_BOUNDS" in os.environ:
|
||||
override_env.update({
|
||||
"TPU_CHIPS_PER_HOST_BOUNDS":
|
||||
os.environ["TPU_CHIPS_PER_HOST_BOUNDS"]
|
||||
})
|
||||
if "TPU_HOST_BOUNDS" in os.environ:
|
||||
override_env.update(
|
||||
{"TPU_HOST_BOUNDS": os.environ["TPU_HOST_BOUNDS"]})
|
||||
|
||||
worker = ray.remote(
|
||||
num_cpus=0,
|
||||
resources={"TPU": 1},
|
||||
@@ -81,6 +94,8 @@ class RayTPUExecutor(TPUExecutor):
|
||||
worker_class_name=worker_class_name,
|
||||
trust_remote_code=self.model_config.trust_remote_code,
|
||||
)
|
||||
if override_env:
|
||||
worker.override_env_vars.remote(override_env)
|
||||
|
||||
worker_ip = ray.get(worker.get_node_ip.remote())
|
||||
if worker_ip == driver_ip and self.driver_dummy_worker is None:
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import os
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
@@ -84,6 +85,9 @@ try:
|
||||
|
||||
return output
|
||||
|
||||
def override_env_vars(self, vars: Dict[str, str]):
|
||||
os.environ.update(vars)
|
||||
|
||||
ray_import_err = None
|
||||
|
||||
except ImportError as e:
|
||||
@@ -291,3 +295,28 @@ def initialize_ray_cluster(
|
||||
_verify_bundles(current_placement_group, parallel_config, device_str)
|
||||
# Set the placement group in the parallel config
|
||||
parallel_config.placement_group = current_placement_group
|
||||
|
||||
|
||||
def get_num_tpu_nodes() -> int:
|
||||
from ray._private.accelerators import TPUAcceleratorManager
|
||||
cluster_resources = ray.cluster_resources()
|
||||
total_tpus = int(cluster_resources["TPU"])
|
||||
tpus_per_node = TPUAcceleratorManager.get_current_node_num_accelerators()
|
||||
assert total_tpus % tpus_per_node == 0
|
||||
return total_tpus // tpus_per_node
|
||||
|
||||
|
||||
def get_num_nodes_in_placement_group() -> int:
|
||||
pg_table = ray.util.placement_group_table()
|
||||
current_pg = ray.util.get_current_placement_group()
|
||||
num_nodes = 0
|
||||
|
||||
if current_pg:
|
||||
nodes_in_pg = set()
|
||||
for pg_key, pg in pg_table.items():
|
||||
if pg_key == current_pg.id.hex():
|
||||
for _, node in pg["bundles_to_node_id"].items():
|
||||
nodes_in_pg.add(node)
|
||||
num_nodes = len(nodes_in_pg)
|
||||
|
||||
return num_nodes
|
||||
|
||||
Reference in New Issue
Block a user