[TPU] Support single and multi-host TPUs on GKE (#7613)

This commit is contained in:
Richard Liu
2024-08-30 00:27:40 -07:00
committed by GitHub
parent dc13e99348
commit 2148441fd3
5 changed files with 74 additions and 4 deletions

View File

@@ -71,6 +71,19 @@ class RayTPUExecutor(TPUExecutor):
worker_module_name = "vllm.worker.tpu_worker"
worker_class_name = "TPUWorker"
# GKE does not fetch environment information from metadata server
# and instead sets these from within the Ray process. Therefore we
# need to override the Ray environment variables manually.
override_env = {}
if "TPU_CHIPS_PER_HOST_BOUNDS" in os.environ:
override_env.update({
"TPU_CHIPS_PER_HOST_BOUNDS":
os.environ["TPU_CHIPS_PER_HOST_BOUNDS"]
})
if "TPU_HOST_BOUNDS" in os.environ:
override_env.update(
{"TPU_HOST_BOUNDS": os.environ["TPU_HOST_BOUNDS"]})
worker = ray.remote(
num_cpus=0,
resources={"TPU": 1},
@@ -81,6 +94,8 @@ class RayTPUExecutor(TPUExecutor):
worker_class_name=worker_class_name,
trust_remote_code=self.model_config.trust_remote_code,
)
if override_env:
worker.override_env_vars.remote(override_env)
worker_ip = ray.get(worker.get_node_ip.remote())
if worker_ip == driver_ip and self.driver_dummy_worker is None:

View File

@@ -1,3 +1,4 @@
import os
import time
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Union
@@ -84,6 +85,9 @@ try:
return output
def override_env_vars(self, vars: Dict[str, str]):
os.environ.update(vars)
ray_import_err = None
except ImportError as e:
@@ -291,3 +295,28 @@ def initialize_ray_cluster(
_verify_bundles(current_placement_group, parallel_config, device_str)
# Set the placement group in the parallel config
parallel_config.placement_group = current_placement_group
def get_num_tpu_nodes() -> int:
from ray._private.accelerators import TPUAcceleratorManager
cluster_resources = ray.cluster_resources()
total_tpus = int(cluster_resources["TPU"])
tpus_per_node = TPUAcceleratorManager.get_current_node_num_accelerators()
assert total_tpus % tpus_per_node == 0
return total_tpus // tpus_per_node
def get_num_nodes_in_placement_group() -> int:
pg_table = ray.util.placement_group_table()
current_pg = ray.util.get_current_placement_group()
num_nodes = 0
if current_pg:
nodes_in_pg = set()
for pg_key, pg in pg_table.items():
if pg_key == current_pg.id.hex():
for _, node in pg["bundles_to_node_id"].items():
nodes_in_pg.add(node)
num_nodes = len(nodes_in_pg)
return num_nodes