Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -11,15 +11,17 @@ import torch.multiprocessing as mp
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
from vllm.distributed.device_communicators.cuda_communicator import (
|
||||
CudaCommunicator)
|
||||
from vllm.distributed.device_communicators.pynccl import (
|
||||
register_nccl_symmetric_ops)
|
||||
from vllm.distributed.device_communicators.cuda_communicator import CudaCommunicator
|
||||
from vllm.distributed.device_communicators.pynccl import register_nccl_symmetric_ops
|
||||
from vllm.distributed.device_communicators.pynccl_allocator import (
|
||||
get_nccl_mem_pool, is_symmetric_memory_enabled)
|
||||
from vllm.distributed.parallel_state import (get_tp_group,
|
||||
init_distributed_environment,
|
||||
initialize_model_parallel)
|
||||
get_nccl_mem_pool,
|
||||
is_symmetric_memory_enabled,
|
||||
)
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tp_group,
|
||||
init_distributed_environment,
|
||||
initialize_model_parallel,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import update_environment_variables
|
||||
|
||||
@@ -38,31 +40,32 @@ def nccl_symm_mem_allreduce_worker(local_rank: int, world_size: int):
|
||||
torch.cuda.set_device(device)
|
||||
torch.set_default_device(device)
|
||||
torch.set_default_dtype(dtype)
|
||||
update_environment_variables({
|
||||
"RANK": str(local_rank),
|
||||
"LOCAL_RANK": str(local_rank),
|
||||
"WORLD_SIZE": str(world_size),
|
||||
"MASTER_ADDR": "localhost",
|
||||
"MASTER_PORT": "12345",
|
||||
})
|
||||
update_environment_variables(
|
||||
{
|
||||
"RANK": str(local_rank),
|
||||
"LOCAL_RANK": str(local_rank),
|
||||
"WORLD_SIZE": str(world_size),
|
||||
"MASTER_ADDR": "localhost",
|
||||
"MASTER_PORT": "12345",
|
||||
}
|
||||
)
|
||||
|
||||
init_distributed_environment()
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
|
||||
cuda_communicator = typing.cast(CudaCommunicator,
|
||||
get_tp_group().device_communicator)
|
||||
cuda_communicator = typing.cast(
|
||||
CudaCommunicator, get_tp_group().device_communicator
|
||||
)
|
||||
pynccl_comm = cuda_communicator.pynccl_comm
|
||||
if get_nccl_mem_pool() is None:
|
||||
pytest.skip("NCCL allocator compilation failed "
|
||||
"(probably missing NCCL headers).")
|
||||
pytest.skip(
|
||||
"NCCL allocator compilation failed (probably missing NCCL headers)."
|
||||
)
|
||||
if not is_symmetric_memory_enabled():
|
||||
pytest.skip("NCCL symmetric memory allreduce is disabled.")
|
||||
|
||||
register_nccl_symmetric_ops(pynccl_comm)
|
||||
input = torch.randint(1,
|
||||
23, (test_size_elements, ),
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
input = torch.randint(1, 23, (test_size_elements,), dtype=dtype, device=device)
|
||||
input_clone = input.clone()
|
||||
output = torch.ops.vllm.all_reduce_symmetric_with_copy(input)
|
||||
assert output is not None
|
||||
@@ -77,8 +80,7 @@ def nccl_symm_mem_allreduce_worker(local_rank: int, world_size: int):
|
||||
reason="NCCLSymmMemAllreduce is only available for CUDA platforms.",
|
||||
)
|
||||
@pytest.mark.parametrize("world_size", [2])
|
||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
|
||||
reason="Only test on CUDA")
|
||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"], reason="Only test on CUDA")
|
||||
def test_nccl_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, world_size):
|
||||
if world_size > torch.cuda.device_count():
|
||||
pytest.skip("Not enough GPUs to run the test.")
|
||||
@@ -88,7 +90,5 @@ def test_nccl_symm_mem_allreduce(monkeypatch: pytest.MonkeyPatch, world_size):
|
||||
monkeypatch.setenv("NCCL_NVLS_ENABLE", "1")
|
||||
monkeypatch.setenv("NCCL_CUMEM_ENABLE", "1")
|
||||
|
||||
mp.spawn(nccl_symm_mem_allreduce_worker,
|
||||
args=(world_size, ),
|
||||
nprocs=world_size)
|
||||
mp.spawn(nccl_symm_mem_allreduce_worker, args=(world_size,), nprocs=world_size)
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
Reference in New Issue
Block a user