Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -6,13 +6,11 @@ import traceback
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
from torch.multiprocessing import (
|
||||
spawn) # pyright: ignore[reportPrivateImportUsage]
|
||||
from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUsage]
|
||||
from typing_extensions import Concatenate, ParamSpec
|
||||
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.distributed import (init_distributed_environment,
|
||||
initialize_model_parallel)
|
||||
from vllm.distributed import init_distributed_environment, initialize_model_parallel
|
||||
from vllm.utils import get_open_port
|
||||
|
||||
## Parallel Processes Utils
|
||||
@@ -30,10 +28,11 @@ class ProcessGroupInfo:
|
||||
device: torch.device
|
||||
|
||||
|
||||
def _set_vllm_config(vllm_config: VllmConfig, world_size: int, rank: int,
|
||||
local_rank: int):
|
||||
|
||||
def _set_vllm_config(
|
||||
vllm_config: VllmConfig, world_size: int, rank: int, local_rank: int
|
||||
):
|
||||
import tempfile
|
||||
|
||||
temp_file = tempfile.mkstemp()[1]
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
@@ -46,13 +45,10 @@ def _set_vllm_config(vllm_config: VllmConfig, world_size: int, rank: int,
|
||||
)
|
||||
|
||||
initialize_model_parallel(
|
||||
tensor_model_parallel_size=vllm_config.parallel_config.
|
||||
tensor_parallel_size,
|
||||
pipeline_model_parallel_size=vllm_config.parallel_config.
|
||||
pipeline_parallel_size,
|
||||
tensor_model_parallel_size=vllm_config.parallel_config.tensor_parallel_size,
|
||||
pipeline_model_parallel_size=vllm_config.parallel_config.pipeline_parallel_size,
|
||||
)
|
||||
cpu_group = torch.distributed.new_group(list(range(world_size)),
|
||||
backend="gloo")
|
||||
cpu_group = torch.distributed.new_group(list(range(world_size)), backend="gloo")
|
||||
return cpu_group
|
||||
|
||||
|
||||
@@ -62,8 +58,7 @@ def _worker_parallel_launch(
|
||||
world_local_size: int,
|
||||
node_rank: int,
|
||||
init_method: str,
|
||||
worker: Callable[Concatenate[ProcessGroupInfo, Optional[VllmConfig], Any,
|
||||
P], None],
|
||||
worker: Callable[Concatenate[ProcessGroupInfo, Optional[VllmConfig], Any, P], None],
|
||||
vllm_config: Optional[VllmConfig],
|
||||
env_dict: Optional[dict],
|
||||
*args: P.args,
|
||||
@@ -131,7 +126,8 @@ def parallel_launch_with_config(
|
||||
worker,
|
||||
vllm_config,
|
||||
env_dict,
|
||||
) + args,
|
||||
)
|
||||
+ args,
|
||||
nprocs=world_size,
|
||||
join=True,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user