Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -10,21 +10,22 @@ import torch
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
|
||||
from vllm.distributed.utils import StatelessProcessGroup
|
||||
from vllm.utils import (cuda_device_count_stateless, get_open_port,
|
||||
update_environment_variables)
|
||||
from vllm.utils import (
|
||||
cuda_device_count_stateless,
|
||||
get_open_port,
|
||||
update_environment_variables,
|
||||
)
|
||||
|
||||
from ..utils import multi_gpu_test
|
||||
|
||||
|
||||
@ray.remote
|
||||
class _CUDADeviceCountStatelessTestActor:
|
||||
|
||||
def get_count(self):
|
||||
return cuda_device_count_stateless()
|
||||
|
||||
def set_cuda_visible_devices(self, cuda_visible_devices: str):
|
||||
update_environment_variables(
|
||||
{"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
|
||||
update_environment_variables({"CUDA_VISIBLE_DEVICES": cuda_visible_devices})
|
||||
|
||||
def get_cuda_visible_devices(self):
|
||||
return envs.CUDA_VISIBLE_DEVICES
|
||||
@@ -34,10 +35,9 @@ def test_cuda_device_count_stateless():
|
||||
"""Test that cuda_device_count_stateless changes return value if
|
||||
CUDA_VISIBLE_DEVICES is changed."""
|
||||
actor = _CUDADeviceCountStatelessTestActor.options( # type: ignore
|
||||
num_gpus=2).remote()
|
||||
assert len(
|
||||
sorted(ray.get(
|
||||
actor.get_cuda_visible_devices.remote()).split(","))) == 2
|
||||
num_gpus=2
|
||||
).remote()
|
||||
assert len(sorted(ray.get(actor.get_cuda_visible_devices.remote()).split(","))) == 2
|
||||
assert ray.get(actor.get_count.remote()) == 2
|
||||
ray.get(actor.set_cuda_visible_devices.remote("0"))
|
||||
assert ray.get(actor.get_count.remote()) == 1
|
||||
@@ -46,15 +46,13 @@ def test_cuda_device_count_stateless():
|
||||
|
||||
|
||||
def cpu_worker(rank, WORLD_SIZE, port1, port2):
|
||||
pg1 = StatelessProcessGroup.create(host="127.0.0.1",
|
||||
port=port1,
|
||||
rank=rank,
|
||||
world_size=WORLD_SIZE)
|
||||
pg1 = StatelessProcessGroup.create(
|
||||
host="127.0.0.1", port=port1, rank=rank, world_size=WORLD_SIZE
|
||||
)
|
||||
if rank <= 2:
|
||||
pg2 = StatelessProcessGroup.create(host="127.0.0.1",
|
||||
port=port2,
|
||||
rank=rank,
|
||||
world_size=3)
|
||||
pg2 = StatelessProcessGroup.create(
|
||||
host="127.0.0.1", port=port2, rank=rank, world_size=3
|
||||
)
|
||||
data = torch.tensor([rank])
|
||||
data = pg1.broadcast_obj(data, src=2)
|
||||
assert data.item() == 2
|
||||
@@ -68,16 +66,14 @@ def cpu_worker(rank, WORLD_SIZE, port1, port2):
|
||||
|
||||
def gpu_worker(rank, WORLD_SIZE, port1, port2):
|
||||
torch.cuda.set_device(rank)
|
||||
pg1 = StatelessProcessGroup.create(host="127.0.0.1",
|
||||
port=port1,
|
||||
rank=rank,
|
||||
world_size=WORLD_SIZE)
|
||||
pg1 = StatelessProcessGroup.create(
|
||||
host="127.0.0.1", port=port1, rank=rank, world_size=WORLD_SIZE
|
||||
)
|
||||
pynccl1 = PyNcclCommunicator(pg1, device=rank)
|
||||
if rank <= 2:
|
||||
pg2 = StatelessProcessGroup.create(host="127.0.0.1",
|
||||
port=port2,
|
||||
rank=rank,
|
||||
world_size=3)
|
||||
pg2 = StatelessProcessGroup.create(
|
||||
host="127.0.0.1", port=port2, rank=rank, world_size=3
|
||||
)
|
||||
pynccl2 = PyNcclCommunicator(pg2, device=rank)
|
||||
data = torch.tensor([rank]).cuda()
|
||||
pynccl1.all_reduce(data)
|
||||
@@ -96,10 +92,9 @@ def gpu_worker(rank, WORLD_SIZE, port1, port2):
|
||||
|
||||
|
||||
def broadcast_worker(rank, WORLD_SIZE, port1, port2):
|
||||
pg1 = StatelessProcessGroup.create(host="127.0.0.1",
|
||||
port=port1,
|
||||
rank=rank,
|
||||
world_size=WORLD_SIZE)
|
||||
pg1 = StatelessProcessGroup.create(
|
||||
host="127.0.0.1", port=port1, rank=rank, world_size=WORLD_SIZE
|
||||
)
|
||||
if rank == 2:
|
||||
pg1.broadcast_obj("secret", src=2)
|
||||
else:
|
||||
@@ -109,10 +104,9 @@ def broadcast_worker(rank, WORLD_SIZE, port1, port2):
|
||||
|
||||
|
||||
def allgather_worker(rank, WORLD_SIZE, port1, port2):
|
||||
pg1 = StatelessProcessGroup.create(host="127.0.0.1",
|
||||
port=port1,
|
||||
rank=rank,
|
||||
world_size=WORLD_SIZE)
|
||||
pg1 = StatelessProcessGroup.create(
|
||||
host="127.0.0.1", port=port1, rank=rank, world_size=WORLD_SIZE
|
||||
)
|
||||
data = pg1.all_gather_obj(rank)
|
||||
assert data == list(range(WORLD_SIZE))
|
||||
pg1.barrier()
|
||||
@@ -121,7 +115,8 @@ def allgather_worker(rank, WORLD_SIZE, port1, port2):
|
||||
@pytest.mark.skip(reason="This test is flaky and prone to hang.")
|
||||
@multi_gpu_test(num_gpus=4)
|
||||
@pytest.mark.parametrize(
|
||||
"worker", [cpu_worker, gpu_worker, broadcast_worker, allgather_worker])
|
||||
"worker", [cpu_worker, gpu_worker, broadcast_worker, allgather_worker]
|
||||
)
|
||||
def test_stateless_process_group(worker):
|
||||
port1 = get_open_port()
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
@@ -129,12 +124,14 @@ def test_stateless_process_group(worker):
|
||||
port2 = get_open_port()
|
||||
WORLD_SIZE = 4
|
||||
from multiprocessing import get_context
|
||||
|
||||
ctx = get_context("fork")
|
||||
processes = []
|
||||
for i in range(WORLD_SIZE):
|
||||
rank = i
|
||||
processes.append(
|
||||
ctx.Process(target=worker, args=(rank, WORLD_SIZE, port1, port2)))
|
||||
ctx.Process(target=worker, args=(rank, WORLD_SIZE, port1, port2))
|
||||
)
|
||||
for p in processes:
|
||||
p.start()
|
||||
for p in processes:
|
||||
|
||||
Reference in New Issue
Block a user