[platform] add ray_device_key (#11948)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@@ -8,6 +8,7 @@ import msgspec
|
|||||||
from vllm.config import ParallelConfig
|
from vllm.config import ParallelConfig
|
||||||
from vllm.executor.msgspec_utils import decode_hook, encode_hook
|
from vllm.executor.msgspec_utils import decode_hook, encode_hook
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.platforms import current_platform
|
||||||
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
|
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
|
||||||
from vllm.utils import get_ip
|
from vllm.utils import get_ip
|
||||||
from vllm.worker.worker_base import WorkerWrapperBase
|
from vllm.worker.worker_base import WorkerWrapperBase
|
||||||
@@ -47,7 +48,12 @@ try:
|
|||||||
|
|
||||||
def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
|
def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
|
||||||
node_id = ray.get_runtime_context().get_node_id()
|
node_id = ray.get_runtime_context().get_node_id()
|
||||||
gpu_ids = ray.get_gpu_ids()
|
device_key = current_platform.ray_device_key
|
||||||
|
if not device_key:
|
||||||
|
raise RuntimeError("current platform %s does not support ray." %
|
||||||
|
current_platform.device_name)
|
||||||
|
gpu_ids = ray.get_runtime_context().get_accelerator_ids(
|
||||||
|
)[device_key]
|
||||||
return node_id, gpu_ids
|
return node_id, gpu_ids
|
||||||
|
|
||||||
def execute_model_spmd(
|
def execute_model_spmd(
|
||||||
@@ -249,11 +255,12 @@ def initialize_ray_cluster(
|
|||||||
# Placement group is already set.
|
# Placement group is already set.
|
||||||
return
|
return
|
||||||
|
|
||||||
device_str = "GPU"
|
device_str = current_platform.ray_device_key
|
||||||
if current_platform.is_tpu():
|
if not device_str:
|
||||||
device_str = "TPU"
|
raise ValueError(
|
||||||
elif current_platform.is_hpu():
|
f"current platform {current_platform.device_name} does not "
|
||||||
device_str = 'HPU'
|
"support ray.")
|
||||||
|
|
||||||
# Create placement group for worker processes
|
# Create placement group for worker processes
|
||||||
current_placement_group = ray.util.get_current_placement_group()
|
current_placement_group = ray.util.get_current_placement_group()
|
||||||
if current_placement_group:
|
if current_placement_group:
|
||||||
|
|||||||
@@ -77,6 +77,7 @@ class CudaPlatformBase(Platform):
|
|||||||
device_name: str = "cuda"
|
device_name: str = "cuda"
|
||||||
device_type: str = "cuda"
|
device_type: str = "cuda"
|
||||||
dispatch_key: str = "CUDA"
|
dispatch_key: str = "CUDA"
|
||||||
|
ray_device_key: str = "GPU"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_device_capability(cls,
|
def get_device_capability(cls,
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ class HpuPlatform(Platform):
|
|||||||
device_name: str = "hpu"
|
device_name: str = "hpu"
|
||||||
device_type: str = "hpu"
|
device_type: str = "hpu"
|
||||||
dispatch_key: str = "HPU"
|
dispatch_key: str = "HPU"
|
||||||
|
ray_device_key: str = "HPU"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
|
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
|
||||||
|
|||||||
@@ -82,6 +82,10 @@ class Platform:
|
|||||||
# check https://github.com/pytorch/pytorch/blob/313dac6c1ca0fa0cde32477509cce32089f8532a/torchgen/model.py#L134 # noqa
|
# check https://github.com/pytorch/pytorch/blob/313dac6c1ca0fa0cde32477509cce32089f8532a/torchgen/model.py#L134 # noqa
|
||||||
# use "CPU" as a fallback for platforms not registered in PyTorch
|
# use "CPU" as a fallback for platforms not registered in PyTorch
|
||||||
dispatch_key: str = "CPU"
|
dispatch_key: str = "CPU"
|
||||||
|
# available ray device keys:
|
||||||
|
# https://github.com/ray-project/ray/blob/10ba5adadcc49c60af2c358a33bb943fb491a171/python/ray/_private/ray_constants.py#L438 # noqa
|
||||||
|
# empty string means the device does not support ray
|
||||||
|
ray_device_key: str = ""
|
||||||
# The torch.compile backend for compiling simple and
|
# The torch.compile backend for compiling simple and
|
||||||
# standalone functions. The default value is "inductor" to keep
|
# standalone functions. The default value is "inductor" to keep
|
||||||
# the same behavior as PyTorch.
|
# the same behavior as PyTorch.
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ class NeuronPlatform(Platform):
|
|||||||
_enum = PlatformEnum.NEURON
|
_enum = PlatformEnum.NEURON
|
||||||
device_name: str = "neuron"
|
device_name: str = "neuron"
|
||||||
device_type: str = "neuron"
|
device_type: str = "neuron"
|
||||||
|
ray_device_key: str = "neuron_cores"
|
||||||
supported_quantization: list[str] = ["neuron_quant"]
|
supported_quantization: list[str] = ["neuron_quant"]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -64,6 +64,8 @@ class RocmPlatform(Platform):
|
|||||||
device_name: str = "rocm"
|
device_name: str = "rocm"
|
||||||
device_type: str = "cuda"
|
device_type: str = "cuda"
|
||||||
dispatch_key: str = "CUDA"
|
dispatch_key: str = "CUDA"
|
||||||
|
ray_device_key: str = "GPU"
|
||||||
|
|
||||||
supported_quantization: list[str] = [
|
supported_quantization: list[str] = [
|
||||||
"awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors",
|
"awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors",
|
||||||
"fbgemm_fp8", "gguf"
|
"fbgemm_fp8", "gguf"
|
||||||
|
|||||||
@@ -19,6 +19,8 @@ class TpuPlatform(Platform):
|
|||||||
device_name: str = "tpu"
|
device_name: str = "tpu"
|
||||||
device_type: str = "tpu"
|
device_type: str = "tpu"
|
||||||
dispatch_key: str = "XLA"
|
dispatch_key: str = "XLA"
|
||||||
|
ray_device_key: str = "TPU"
|
||||||
|
|
||||||
supported_quantization: list[str] = [
|
supported_quantization: list[str] = [
|
||||||
"tpu_int8", "compressed-tensors", "compressed_tensors"
|
"tpu_int8", "compressed-tensors", "compressed_tensors"
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -19,6 +19,9 @@ class XPUPlatform(Platform):
|
|||||||
device_name: str = "xpu"
|
device_name: str = "xpu"
|
||||||
device_type: str = "xpu"
|
device_type: str = "xpu"
|
||||||
dispatch_key: str = "XPU"
|
dispatch_key: str = "XPU"
|
||||||
|
# Intel XPU's device key is "GPU" for Ray.
|
||||||
|
# see https://github.com/ray-project/ray/blob/6a5eb5865eeb9ccf058a79b44f107e327e360673/python/ray/_private/accelerators/intel_gpu.py#L20 # noqa: E501
|
||||||
|
ray_device_key: str = "GPU"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
|
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
|
||||||
|
|||||||
@@ -41,7 +41,12 @@ try:
|
|||||||
|
|
||||||
def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
|
def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]:
|
||||||
node_id = ray.get_runtime_context().get_node_id()
|
node_id = ray.get_runtime_context().get_node_id()
|
||||||
gpu_ids = ray.get_gpu_ids()
|
device_key = current_platform.ray_device_key
|
||||||
|
if not device_key:
|
||||||
|
raise RuntimeError("current platform %s does not support ray." %
|
||||||
|
current_platform.device_name)
|
||||||
|
gpu_ids = ray.get_runtime_context().get_accelerator_ids(
|
||||||
|
)[device_key]
|
||||||
return node_id, gpu_ids
|
return node_id, gpu_ids
|
||||||
|
|
||||||
def setup_device_if_necessary(self):
|
def setup_device_if_necessary(self):
|
||||||
@@ -211,7 +216,11 @@ def initialize_ray_cluster(
|
|||||||
# Placement group is already set.
|
# Placement group is already set.
|
||||||
return
|
return
|
||||||
|
|
||||||
device_str = "GPU" if not current_platform.is_tpu() else "TPU"
|
device_str = current_platform.ray_device_key
|
||||||
|
if not device_str:
|
||||||
|
raise ValueError(
|
||||||
|
f"current platform {current_platform.device_name} does not "
|
||||||
|
"support ray.")
|
||||||
# Create placement group for worker processes
|
# Create placement group for worker processes
|
||||||
current_placement_group = ray.util.get_current_placement_group()
|
current_placement_group = ray.util.get_current_placement_group()
|
||||||
if current_placement_group:
|
if current_placement_group:
|
||||||
|
|||||||
Reference in New Issue
Block a user