[Core][Distributed] use cpu group to broadcast metadata in cpu (#4444)
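In words: `broadcast_tensor_dict` previously sent both the dict's structure (a pickled Python object) and the tensors themselves over the same device group. This commit splits the two, sending the picklable metadata over a gloo-backed CPU group. Below is a minimal, runnable single-process sketch of that idea; the port and group setup are illustrative stand-ins, not vLLM code.

import torch.distributed as dist

# `broadcast_object_list` pickles on the source rank and unpickles on the
# receivers -- pure CPU work -- so a gloo group can carry it without ever
# touching the NCCL communicator that moves the actual tensors.
dist.init_process_group(backend="gloo",
                        init_method="tcp://127.0.0.1:29500",
                        world_size=1,
                        rank=0)
cpu_group = dist.new_group(ranks=[0], backend="gloo")

metadata = [[("x", ("float16", (2, 3))), ("step", 7)]]
dist.broadcast_object_list(metadata, src=0, group=cpu_group)
print(metadata[0])  # same structure, now present on every rank
dist.destroy_process_group()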
@@ -6,14 +6,14 @@ import uuid
 from functools import partial
 from typing import Type
 
-import torch
 import torch.nn as nn
 from tensorizer import (DecryptionParams, EncryptionParams, TensorDeserializer,
                         TensorSerializer, stream_io)
 from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor
 from transformers import AutoConfig, PretrainedConfig
 
-from vllm.distributed import initialize_model_parallel
+from vllm.distributed import (init_distributed_environment,
+                              initialize_model_parallel)
 from vllm.engine.arg_utils import EngineArgs
 from vllm.engine.llm_engine import LLMEngine
 from vllm.model_executor.model_loader.tensorizer import TensorizerArgs
@@ -226,7 +226,7 @@ model_name = model_ref.split("/")[1]
 os.environ["MASTER_ADDR"] = "127.0.0.1"
 os.environ["MASTER_PORT"] = "8080"
 
-torch.distributed.init_process_group(world_size=1, rank=0)
+init_distributed_environment(world_size=1, rank=0, local_rank=0)
 initialize_model_parallel()
 
 keyfile = args.keyfile if args.keyfile else None
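The example script now calls vLLM's `init_distributed_environment` instead of raw `torch.distributed.init_process_group`. Judging from this diff (note the new `get_cpu_world_group` import further down), the wrapper also creates the gloo-backed CPU group that the metadata broadcast relies on. A rough sketch of that wrapper pattern, assuming it mirrors vLLM's `parallel_state` of this era; this is not the actual source:

import torch.distributed as dist

_CPU_WORLD_GROUP = None

def init_distributed_environment(world_size: int,
                                 rank: int,
                                 distributed_init_method: str = "env://",
                                 local_rank: int = -1,
                                 backend: str = "nccl") -> None:
    """Sketch: bring up the default (GPU) group, plus a gloo twin."""
    global _CPU_WORLD_GROUP
    if not dist.is_initialized():
        dist.init_process_group(backend=backend,
                                init_method=distributed_init_method,
                                world_size=world_size,
                                rank=rank)
    # A second group over the same ranks, backed by gloo, reserved for
    # CPU-side collectives such as broadcast_object_list.
    _CPU_WORLD_GROUP = dist.new_group(ranks=list(range(world_size)),
                                      backend="gloo")

def get_cpu_world_group():
    return _CPU_WORLD_GROUP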
@@ -2,8 +2,10 @@ import pytest
 import torch
 
 from vllm.config import ModelConfig, SchedulerConfig
+from vllm.distributed.parallel_state import init_distributed_environment
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
+from vllm.utils import get_open_port
 from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size
 
 
@@ -249,19 +251,18 @@ def test_empty_seq_group():
     assert len(return_prompt_lens) == 0
 
 
+@pytest.fixture
+def distributed_init():
+    init_distributed_environment(
+        world_size=1,
+        rank=0,
+        distributed_init_method=f"tcp://127.0.0.1:{get_open_port()}",
+        local_rank=0)
+
+
 @pytest.mark.parametrize("batch_size", list(range(2, 128)))
 @pytest.mark.parametrize("enforce_eager", [True, False])
-def test_hybrid_batches(batch_size, enforce_eager, monkeypatch):
+def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
 
-    def get_world_size(group=None):
-        return 1
-
-    def mock_get_process_group_ranks(group=None):
-        return [0]
-
-    monkeypatch.setattr(torch.distributed, "get_world_size", get_world_size)
-    monkeypatch.setattr(torch.distributed, "get_process_group_ranks",
-                        mock_get_process_group_ranks)
-
     model_config = ModelConfig(
         "facebook/opt-125m",
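Why the test signature changes from `monkeypatch` to `distributed_init`: pytest injects a fixture whenever a test names it as a parameter, so requesting the fixture guarantees a real single-process process group exists before the test body runs, and the `torch.distributed` stubs become unnecessary. A self-contained sketch of the injection mechanics (the fixture body here is a placeholder for the real `init_distributed_environment` call):

import pytest

@pytest.fixture
def distributed_init():
    # Placeholder for init_distributed_environment(world_size=1, rank=0, ...);
    # yielding lets a real fixture tear the group down afterwards.
    yield "process group ready"

def test_uses_real_group(distributed_init):
    # No monkeypatching: the fixture already ran before this body.
    assert distributed_init == "process group ready"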
@@ -4,7 +4,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import torch
 from torch.distributed import ProcessGroup
 
-from .parallel_state import (get_tensor_model_parallel_group,
+from .parallel_state import (get_cpu_world_group,
+                             get_tensor_model_parallel_group,
                              get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              is_pynccl_enabled_for_all_reduce)
@@ -140,13 +141,46 @@ def broadcast_object_list(obj_list: List[Any],
 TensorMetadata = namedtuple("TensorMetadata", ["dtype", "size"])
 
 
+def _split_tensor_dict(
+        tensor_dict: Dict[Any, Union[torch.Tensor, Any]]
+) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
+    """Split the tensor dictionary into two parts:
+    1. A list of (key, value) pairs. If the value is a tensor, it is replaced
+         by its metadata.
+    2. A list of tensors.
+    """
+    metadata_list = []
+    tensor_list = []
+    for key, value in tensor_dict.items():
+        if isinstance(value, torch.Tensor):
+            # Note(youkaichao): currently this only supports broadcasting
+            # tensors on cuda. In the future, we can add device as a field in
+            # TensorMetadata to support broadcasting tensors on different
+            # devices.
+            assert value.is_cuda, (
+                f"Tensor {key}: {value} is not on cuda. Currently we only "
+                f"support broadcasting tensors on cuda.")
+            metadata_list.append((key, TensorMetadata(value.dtype,
+                                                      value.size())))
+            tensor_list.append(value)
+        else:
+            metadata_list.append((key, value))
+    return metadata_list, tensor_list
+
+
 def broadcast_tensor_dict(
     tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
     src: int = 0,
     group: Optional[ProcessGroup] = None,
+    metadata_group: Optional[ProcessGroup] = None
 ) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
-    """Broadcast the input tensor dictionary."""
+    """Broadcast the input tensor dictionary.
+    `group` is used to broadcast the tensors, while `metadata_group` is used
+    to broadcast the metadata of the dict (e.g. dict structure, tensor sizes,
+    dtypes).
+    """
     group = group or torch.distributed.group.WORLD
+    metadata_group = metadata_group or get_cpu_world_group()
     ranks = torch.distributed.get_process_group_ranks(group)
     assert src in ranks, f"Invalid src rank ({src})"
 
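The receiver side of `broadcast_tensor_dict` reverses `_split_tensor_dict` by pairing each `TensorMetadata` entry with the next received tensor, in order. The diff performs that reassembly inline; `reassemble` below is a hypothetical standalone helper illustrating the same round trip on CPU tensors (the real code asserts CUDA):

from collections import namedtuple

import torch

TensorMetadata = namedtuple("TensorMetadata", ["dtype", "size"])

def reassemble(metadata_list, tensor_list):
    # Pair metadata entries back up with tensors, preserving order.
    tensors = iter(tensor_list)
    return {
        key: next(tensors) if isinstance(value, TensorMetadata) else value
        for key, value in metadata_list
    }

metadata_list = [("x", TensorMetadata(torch.float32, torch.Size([2]))),
                 ("step", 7)]
tensor_list = [torch.zeros(2)]
assert reassemble(metadata_list, tensor_list)["step"] == 7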
@@ -161,27 +195,20 @@ def broadcast_tensor_dict(
         assert isinstance(
             tensor_dict,
             dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
-        for key, value in tensor_dict.items():
-            if isinstance(value, torch.Tensor):
-                assert value.is_cuda, (
-                    f"Tensor {key}: {value} is not on cuda. Currently we only "
-                    f"support broadcasting tensors on cuda.")
-                metadata_list.append(
-                    (key, TensorMetadata(value.dtype, value.size())))
-            else:
-                metadata_list.append((key, value))
+        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
+        # `metadata_list` lives in CPU memory.
+        # `broadcast_object_list` involves serialization and deserialization,
+        # all happening on CPU. Therefore, we can use the CPU group.
         torch.distributed.broadcast_object_list([metadata_list],
                                                 src=src,
-                                                group=group)
+                                                group=metadata_group)
         async_handles = []
-        for key, value in metadata_list:
-            if isinstance(value, TensorMetadata):
-                tensor = tensor_dict[key]
-                async_handles.append(
-                    torch.distributed.broadcast(tensor,
-                                                src=src,
-                                                group=group,
-                                                async_op=True))
+        for tensor in tensor_list:
+            async_handles.append(
+                torch.distributed.broadcast(tensor,
+                                            src=src,
+                                            group=group,
+                                            async_op=True))
         for async_handle in async_handles:
             async_handle.wait()
 
@@ -189,7 +216,7 @@ def broadcast_tensor_dict(
         recv_metadata_list = [None]
         torch.distributed.broadcast_object_list(recv_metadata_list,
                                                 src=src,
-                                                group=group)
+                                                group=metadata_group)
         assert recv_metadata_list[0] is not None
         tensor_dict = {}
         async_handles = []
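From the caller's perspective the API is unchanged apart from the new optional `metadata_group`, which defaults to the CPU world group. A usage sketch under an already-initialized vLLM distributed environment; the `vllm.distributed` import path is assumed from this diff's relative imports and may differ across versions:

import torch
from vllm.distributed import broadcast_tensor_dict

if torch.distributed.get_rank() == 0:
    # Driver rank: tensors ride the (NCCL) `group`, while the dict skeleton
    # and scalar values ride the gloo-backed default `metadata_group`.
    broadcast_tensor_dict(
        {"input_ids": torch.ones(4, dtype=torch.long).cuda(), "step": 7},
        src=0)
else:
    data = broadcast_tensor_dict(src=0)
    # data == {"input_ids": <cuda tensor of ones>, "step": 7}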