[platform] add ray_device_key (#11948)

Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
youkaichao
2025-01-13 16:20:52 +08:00
committed by GitHub
parent c3f05b09a0
commit 89ce62a316
9 changed files with 38 additions and 8 deletions

View File

@@ -41,7 +41,12 @@ try:
def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
node_id = ray.get_runtime_context().get_node_id()
gpu_ids = ray.get_gpu_ids()
device_key = current_platform.ray_device_key
if not device_key:
raise RuntimeError("current platform %s does not support ray.",
current_platform.device_name)
gpu_ids = ray.get_runtime_context().get_accelerator_ids(
)[device_key]
return node_id, gpu_ids
def setup_device_if_necessary(self):
@@ -211,7 +216,11 @@ def initialize_ray_cluster(
# Placement group is already set.
return
device_str = "GPU" if not current_platform.is_tpu() else "TPU"
device_str = current_platform.ray_device_key
if not device_str:
raise ValueError(
f"current platform {current_platform.device_name} does not "
"support ray.")
# Create placement group for worker processes
current_placement_group = ray.util.get_current_placement_group()
if current_placement_group: