[Core][Distributed] use cpu group to broadcast metadata in cpu (#4444)

Author: youkaichao
Date: 2024-04-29 13:52:22 -07:00
Committed by: GitHub
Parent: ac5ccf0156
Commit: f4f921b7f1

3 changed files with 63 additions and 35 deletions

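Summary of the change: `broadcast_tensor_dict` previously sent both the pickled dict metadata and the tensor payloads over the same (device) process group. After this commit, tensors still travel over the device group, while the pickled metadata (dict structure, dtypes, shapes) is broadcast over a gloo-backed CPU group, so the serialization/deserialization done by `broadcast_object_list` never has to detour through the GPU. The sketch below shows the same two-group pattern in plain `torch.distributed`, independent of vLLM; the script layout and the `torchrun --nproc-per-node=2` launch on a two-GPU machine are assumptions for illustration only.

import os

import torch
import torch.distributed as dist


def main():
    # Assumed launch: torchrun --nproc-per-node=2 demo.py (two GPUs).
    local_rank = int(os.environ.get("LOCAL_RANK", "0"))
    torch.cuda.set_device(local_rank)
    # Default group carries GPU tensors (NCCL)...
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    # ...while a separate gloo group keeps pickled metadata on the CPU.
    cpu_group = dist.new_group(backend="gloo")

    if rank == 0:
        payload = {"input_ids": torch.arange(4, device="cuda")}
        meta = [[(k, (v.dtype, tuple(v.shape))) for k, v in payload.items()]]
    else:
        payload = {}
        meta = [None]

    # Metadata (dict structure, dtypes, shapes) is pickled/unpickled on the
    # CPU by broadcast_object_list, so it can ride the gloo group.
    dist.broadcast_object_list(meta, src=0, group=cpu_group)

    # Tensor payloads stay on the fast device group; receivers allocate
    # buffers from the metadata they just received.
    for key, (dtype, shape) in meta[0]:
        if rank != 0:
            payload[key] = torch.empty(shape, dtype=dtype, device="cuda")
        dist.broadcast(payload[key], src=0, group=dist.group.WORLD)

    print(f"rank {rank}: {payload}")
    dist.destroy_process_group()


if __name__ == "__main__":
    main()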
View File

@@ -6,14 +6,14 @@ import uuid
 from functools import partial
 from typing import Type

-import torch
 import torch.nn as nn
 from tensorizer import (DecryptionParams, EncryptionParams, TensorDeserializer,
                         TensorSerializer, stream_io)
 from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor
 from transformers import AutoConfig, PretrainedConfig

-from vllm.distributed import initialize_model_parallel
+from vllm.distributed import (init_distributed_environment,
+                              initialize_model_parallel)
 from vllm.engine.arg_utils import EngineArgs
 from vllm.engine.llm_engine import LLMEngine
 from vllm.model_executor.model_loader.tensorizer import TensorizerArgs
@@ -226,7 +226,7 @@ model_name = model_ref.split("/")[1]
 os.environ["MASTER_ADDR"] = "127.0.0.1"
 os.environ["MASTER_PORT"] = "8080"
-torch.distributed.init_process_group(world_size=1, rank=0)
+init_distributed_environment(world_size=1, rank=0, local_rank=0)
 initialize_model_parallel()
 keyfile = args.keyfile if args.keyfile else None

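Note on the first file: the tensorizer script now initializes distributed state through vLLM's `init_distributed_environment` helper instead of calling `torch.distributed.init_process_group` directly. The helper wraps that call and also sets up the gloo-backed CPU world group that `get_cpu_world_group()` returns, which the `communication_op.py` changes below rely on for metadata broadcasts; a bare `init_process_group` call would leave that group uncreated.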
View File

@@ -2,8 +2,10 @@ import pytest
 import torch

 from vllm.config import ModelConfig, SchedulerConfig
+from vllm.distributed.parallel_state import init_distributed_environment
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
+from vllm.utils import get_open_port
 from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size
@@ -249,19 +251,18 @@ def test_empty_seq_group():
     assert len(return_prompt_lens) == 0


+@pytest.fixture
+def distributed_init():
+    init_distributed_environment(
+        world_size=1,
+        rank=0,
+        distributed_init_method=f"tcp://127.0.0.1:{get_open_port()}",
+        local_rank=0)
+
+
 @pytest.mark.parametrize("batch_size", list(range(2, 128)))
 @pytest.mark.parametrize("enforce_eager", [True, False])
-def test_hybrid_batches(batch_size, enforce_eager, monkeypatch):
-
-    def get_world_size(group=None):
-        return 1
-
-    def mock_get_process_group_ranks(group=None):
-        return [0]
-
-    monkeypatch.setattr(torch.distributed, "get_world_size", get_world_size)
-    monkeypatch.setattr(torch.distributed, "get_process_group_ranks",
-                        mock_get_process_group_ranks)
+def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
     model_config = ModelConfig(
         "facebook/opt-125m",

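With `broadcast_tensor_dict` now touching two real process groups, the test swaps its `monkeypatch` stubs of `torch.distributed.get_world_size` / `get_process_group_ranks` for an actual single-rank group built by the new `distributed_init` fixture. A stand-alone analogue in plain `torch.distributed` might look like the sketch below; the fixture and test names, and the port, are mine rather than vLLM's.

import pytest
import torch.distributed as dist


@pytest.fixture
def single_rank_group():
    # A real (single-rank) gloo group so that torch.distributed collectives
    # work in a unit test without GPUs or a multi-process launcher.
    if not dist.is_initialized():
        # Port 29501 is assumed to be free on the test machine.
        dist.init_process_group(backend="gloo",
                                init_method="tcp://127.0.0.1:29501",
                                world_size=1,
                                rank=0)
    yield dist.group.WORLD


def test_world_size(single_rank_group):
    assert dist.get_world_size(single_rank_group) == 1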
View File

@@ -4,7 +4,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union
 import torch
 from torch.distributed import ProcessGroup

-from .parallel_state import (get_tensor_model_parallel_group,
+from .parallel_state import (get_cpu_world_group,
+                             get_tensor_model_parallel_group,
                              get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              is_pynccl_enabled_for_all_reduce)
@@ -140,13 +141,46 @@ def broadcast_object_list(obj_list: List[Any],
 TensorMetadata = namedtuple("TensorMetadata", ["dtype", "size"])


+def _split_tensor_dict(
+    tensor_dict: Dict[Any, Union[torch.Tensor, Any]]
+) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
+    """Split the tensor dictionary into two parts:
+    1. A list of (key, value) pairs. If the value is a tensor, it is replaced
+         by its metadata.
+    2. A list of tensors.
+    """
+    metadata_list = []
+    tensor_list = []
+    for key, value in tensor_dict.items():
+        if isinstance(value, torch.Tensor):
+            # Note(youkaichao): currently this only supports broadcasting
+            # tensors on cuda. In the future, we can add device as a field in
+            # TensorMetadata to support broadcasting tensors on different
+            # devices.
+            assert value.is_cuda, (
+                f"Tensor {key}: {value} is not on cuda. Currently we only "
+                f"support broadcasting tensors on cuda.")
+            metadata_list.append((key, TensorMetadata(value.dtype,
+                                                      value.size())))
+            tensor_list.append(value)
+        else:
+            metadata_list.append((key, value))
+    return metadata_list, tensor_list
+
+
 def broadcast_tensor_dict(
     tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None,
     src: int = 0,
     group: Optional[ProcessGroup] = None,
+    metadata_group: Optional[ProcessGroup] = None
 ) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]:
-    """Broadcast the input tensor dictionary."""
+    """Broadcast the input tensor dictionary.
+    `group` is used to broadcast the tensors, while `metadata_group` is used
+    to broadcast the metadata of the dict (e.g. dict structure, tensor sizes,
+    dtypes).
+    """
     group = group or torch.distributed.group.WORLD
+    metadata_group = metadata_group or get_cpu_world_group()
     ranks = torch.distributed.get_process_group_ranks(group)
     assert src in ranks, f"Invalid src rank ({src})"
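`_split_tensor_dict` turns the dict into (a) a small, picklable metadata list in which every tensor is replaced by a `TensorMetadata(dtype, size)` stub, and (b) the list of tensors in the same order. Only (a) goes through `broadcast_object_list` on the CPU group; (b) is broadcast tensor by tensor on the device group. The toy example below is illustration only: it runs on CPU just to show the data flow (the real helper asserts the tensors are on CUDA), and `rebuild_tensor_dict` is my name for what the receiving branch later does inline.

from collections import namedtuple
from typing import Any, Dict

import torch

TensorMetadata = namedtuple("TensorMetadata", ["dtype", "size"])


def rebuild_tensor_dict(metadata_list, tensor_list):
    # Hypothetical inverse of _split_tensor_dict: put each tensor back into
    # the slot described by its metadata stub, in order.
    tensors = iter(tensor_list)
    out: Dict[Any, Any] = {}
    for key, value in metadata_list:
        out[key] = next(tensors) if isinstance(value, TensorMetadata) else value
    return out


sample = {"input_ids": torch.arange(6), "is_prompt": True, "seq_len": 6}
metadata_list, tensor_list = [], []
for key, value in sample.items():
    if isinstance(value, torch.Tensor):
        # Small picklable stub -> CPU (gloo) group.
        metadata_list.append((key, TensorMetadata(value.dtype, value.size())))
        # Actual payload -> device (e.g. NCCL) group.
        tensor_list.append(value)
    else:
        metadata_list.append((key, value))

rebuilt = rebuild_tensor_dict(metadata_list, tensor_list)
assert torch.equal(rebuilt["input_ids"], sample["input_ids"])
assert rebuilt["is_prompt"] is True and rebuilt["seq_len"] == 6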
@@ -161,27 +195,20 @@ def broadcast_tensor_dict(
         assert isinstance(
             tensor_dict,
             dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
-        for key, value in tensor_dict.items():
-            if isinstance(value, torch.Tensor):
-                assert value.is_cuda, (
-                    f"Tensor {key}: {value} is not on cuda. Currently we only "
-                    f"support broadcasting tensors on cuda.")
-                metadata_list.append(
-                    (key, TensorMetadata(value.dtype, value.size())))
-            else:
-                metadata_list.append((key, value))
+        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
+        # `metadata_list` lives in CPU memory.
+        # `broadcast_object_list` involves serialization and deserialization,
+        # all happening on CPU. Therefore, we can use the CPU group.
         torch.distributed.broadcast_object_list([metadata_list],
                                                 src=src,
-                                                group=group)
+                                                group=metadata_group)
         async_handles = []
-        for key, value in metadata_list:
-            if isinstance(value, TensorMetadata):
-                tensor = tensor_dict[key]
-                async_handles.append(
-                    torch.distributed.broadcast(tensor,
-                                                src=src,
-                                                group=group,
-                                                async_op=True))
+        for tensor in tensor_list:
+            async_handles.append(
+                torch.distributed.broadcast(tensor,
+                                            src=src,
+                                            group=group,
+                                            async_op=True))
         for async_handle in async_handles:
             async_handle.wait()
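On the source rank, the new flow is: split the dict once with `_split_tensor_dict`, broadcast the metadata list synchronously on `metadata_group` (receivers need it before they can allocate receive buffers), then issue one `torch.distributed.broadcast(..., async_op=True)` per tensor on `group` and wait on all handles together, so the per-tensor broadcasts can overlap instead of running back to back. The receiving branch below mirrors this: metadata first on the CPU group, then the tensors on the device group.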
@@ -189,7 +216,7 @@ def broadcast_tensor_dict(
         recv_metadata_list = [None]
         torch.distributed.broadcast_object_list(recv_metadata_list,
                                                 src=src,
-                                                group=group)
+                                                group=metadata_group)
         assert recv_metadata_list[0] is not None
         tensor_dict = {}
         async_handles = []