Refactor system architecture (#82)
This commit is contained in:
593
cacheflow/model_executor/parallel_utils/parallel_state.py
Normal file
593
cacheflow/model_executor/parallel_utils/parallel_state.py
Normal file
@@ -0,0 +1,593 @@
|
||||
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||
|
||||
"""Model and data parallel groups."""
|
||||
|
||||
import torch
|
||||
from typing import Optional
|
||||
|
||||
from .utils import GlobalMemoryBuffer
|
||||
|
||||
# Intra-layer model parallel group that the current rank belongs to.
|
||||
_TENSOR_MODEL_PARALLEL_GROUP = None
|
||||
# Inter-layer model parallel group that the current rank belongs to.
|
||||
_PIPELINE_MODEL_PARALLEL_GROUP = None
|
||||
# Model parallel group (both intra- and pipeline) that the current rank belongs to.
|
||||
_MODEL_PARALLEL_GROUP = None
|
||||
# Embedding group.
|
||||
_EMBEDDING_GROUP = None
|
||||
# Position embedding group.
|
||||
_POSITION_EMBEDDING_GROUP = None
|
||||
# Data parallel group that the current rank belongs to.
|
||||
_DATA_PARALLEL_GROUP = None
|
||||
|
||||
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
|
||||
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
|
||||
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None
|
||||
|
||||
# These values enable us to change the mpu sizes on the fly.
|
||||
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
|
||||
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
|
||||
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
|
||||
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None
|
||||
|
||||
# A list of ranks that have a copy of the embedding.
|
||||
_EMBEDDING_GLOBAL_RANKS = None
|
||||
|
||||
# A list of ranks that have a copy of the position embedding.
|
||||
_POSITION_EMBEDDING_GLOBAL_RANKS = None
|
||||
|
||||
# A list of global ranks for each pipeline group to ease calculation of the source
|
||||
# rank when broadcasting from the first or last pipeline stage.
|
||||
_PIPELINE_GLOBAL_RANKS = None
|
||||
|
||||
# A list of global ranks for each data parallel group to ease calculation of the source
|
||||
# rank when broadcasting weights from src to all other data parallel ranks
|
||||
_DATA_PARALLEL_GLOBAL_RANKS = None
|
||||
|
||||
# Memory buffers to avoid dynamic memory allocation
|
||||
_GLOBAL_MEMORY_BUFFER = None
|
||||
|
||||
_ALL_REDUCE_LAUNCHER: Optional['GraphAllReduce'] = None
|
||||
|
||||
def initialize_model_parallel(
|
||||
tensor_model_parallel_size: int = 1,
|
||||
pipeline_model_parallel_size: int = 1,
|
||||
virtual_pipeline_model_parallel_size: Optional[int] = None,
|
||||
pipeline_model_parallel_split_rank: Optional[int] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize model data parallel groups.
|
||||
|
||||
Arguments:
|
||||
tensor_model_parallel_size: number of GPUs used for tensor model parallelism.
|
||||
pipeline_model_parallel_size: number of GPUs used for pipeline model parallelism.
|
||||
virtual_pipeline_model_parallel_size: number of virtual stages (interleaved
|
||||
pipeline).
|
||||
pipeline_model_parallel_split_rank: for models with both encoder and decoder,
|
||||
rank in pipeline with split point.
|
||||
|
||||
Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
|
||||
use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
|
||||
the model pipeline. The present function will
|
||||
create 8 tensor model-parallel groups, 4 pipeline model-parallel groups
|
||||
and 8 data-parallel groups as:
|
||||
8 data_parallel groups:
|
||||
[g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
|
||||
8 tensor model-parallel groups:
|
||||
[g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
|
||||
4 pipeline model-parallel groups:
|
||||
[g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]
|
||||
Note that for efficiency, the caller should make sure adjacent ranks
|
||||
are on the same DGX box. For example if we are using 2 DGX-1 boxes
|
||||
with a total of 16 GPUs, rank 0 to 7 belong to the first box and
|
||||
ranks 8 to 15 belong to the second box.
|
||||
"""
|
||||
# Get world size and rank. Ensure some consistencies.
|
||||
assert torch.distributed.is_initialized()
|
||||
world_size: int = torch.distributed.get_world_size()
|
||||
|
||||
if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0:
|
||||
raise RuntimeError(
|
||||
f"world_size ({world_size}) is not divisible by tensor_model_parallel_size "
|
||||
f"({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size})"
|
||||
)
|
||||
|
||||
data_parallel_size: int = world_size // (tensor_model_parallel_size *
|
||||
pipeline_model_parallel_size)
|
||||
|
||||
num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size
|
||||
num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
|
||||
num_data_parallel_groups: int = world_size // data_parallel_size
|
||||
|
||||
if virtual_pipeline_model_parallel_size is not None:
|
||||
if not pipeline_model_parallel_size > 2:
|
||||
raise RuntimeError("pipeline-model-parallel size should be greater than 2 with "
|
||||
"interleaved schedule")
|
||||
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
|
||||
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
|
||||
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
|
||||
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size
|
||||
|
||||
if pipeline_model_parallel_split_rank is not None:
|
||||
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
|
||||
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank
|
||||
|
||||
rank = torch.distributed.get_rank()
|
||||
|
||||
# Build the data-parallel groups.
|
||||
global _DATA_PARALLEL_GROUP
|
||||
global _DATA_PARALLEL_GLOBAL_RANKS
|
||||
assert _DATA_PARALLEL_GROUP is None, 'data parallel group is already initialized'
|
||||
all_data_parallel_group_ranks = []
|
||||
for i in range(pipeline_model_parallel_size):
|
||||
start_rank = i * num_pipeline_model_parallel_groups
|
||||
end_rank = (i + 1) * num_pipeline_model_parallel_groups
|
||||
for j in range(tensor_model_parallel_size):
|
||||
ranks = range(start_rank + j, end_rank, tensor_model_parallel_size)
|
||||
all_data_parallel_group_ranks.append(list(ranks))
|
||||
group = torch.distributed.new_group(ranks)
|
||||
if rank in ranks:
|
||||
_DATA_PARALLEL_GROUP = group
|
||||
_DATA_PARALLEL_GLOBAL_RANKS = ranks
|
||||
|
||||
# Build the model-parallel groups.
|
||||
global _MODEL_PARALLEL_GROUP
|
||||
assert _MODEL_PARALLEL_GROUP is None, 'model parallel group is already initialized'
|
||||
for i in range(data_parallel_size):
|
||||
ranks = [data_parallel_group_ranks[i]
|
||||
for data_parallel_group_ranks in all_data_parallel_group_ranks]
|
||||
group = torch.distributed.new_group(ranks)
|
||||
if rank in ranks:
|
||||
_MODEL_PARALLEL_GROUP = group
|
||||
|
||||
# Build the tensor model-parallel groups.
|
||||
global _TENSOR_MODEL_PARALLEL_GROUP
|
||||
assert _TENSOR_MODEL_PARALLEL_GROUP is None, \
|
||||
'tensor model parallel group is already initialized'
|
||||
for i in range(num_tensor_model_parallel_groups):
|
||||
ranks = range(i * tensor_model_parallel_size,
|
||||
(i + 1) * tensor_model_parallel_size)
|
||||
group = torch.distributed.new_group(ranks)
|
||||
if rank in ranks:
|
||||
_TENSOR_MODEL_PARALLEL_GROUP = group
|
||||
|
||||
# Build the pipeline model-parallel groups and embedding groups
|
||||
# (first and last rank in each pipeline model-parallel group).
|
||||
global _PIPELINE_MODEL_PARALLEL_GROUP
|
||||
global _PIPELINE_GLOBAL_RANKS
|
||||
assert _PIPELINE_MODEL_PARALLEL_GROUP is None, \
|
||||
'pipeline model parallel group is already initialized'
|
||||
global _EMBEDDING_GROUP
|
||||
global _EMBEDDING_GLOBAL_RANKS
|
||||
assert _EMBEDDING_GROUP is None, 'embedding group is already initialized'
|
||||
global _POSITION_EMBEDDING_GROUP
|
||||
global _POSITION_EMBEDDING_GLOBAL_RANKS
|
||||
assert _POSITION_EMBEDDING_GROUP is None, \
|
||||
'position embedding group is already initialized'
|
||||
for i in range(num_pipeline_model_parallel_groups):
|
||||
ranks = range(i, world_size, num_pipeline_model_parallel_groups)
|
||||
group = torch.distributed.new_group(ranks)
|
||||
if rank in ranks:
|
||||
_PIPELINE_MODEL_PARALLEL_GROUP = group
|
||||
_PIPELINE_GLOBAL_RANKS = ranks
|
||||
# Setup embedding group (to exchange gradients between
|
||||
# first and last stages).
|
||||
if len(ranks) > 1:
|
||||
embedding_ranks = [ranks[0], ranks[-1]]
|
||||
position_embedding_ranks = [ranks[0]]
|
||||
if pipeline_model_parallel_split_rank is not None:
|
||||
if ranks[pipeline_model_parallel_split_rank] not in embedding_ranks:
|
||||
embedding_ranks = [ranks[0],
|
||||
ranks[pipeline_model_parallel_split_rank],
|
||||
ranks[-1]]
|
||||
if ranks[pipeline_model_parallel_split_rank] not in position_embedding_ranks:
|
||||
position_embedding_ranks = [ranks[0],
|
||||
ranks[pipeline_model_parallel_split_rank]]
|
||||
else:
|
||||
embedding_ranks = ranks
|
||||
position_embedding_ranks = ranks
|
||||
|
||||
group = torch.distributed.new_group(embedding_ranks)
|
||||
if rank in embedding_ranks:
|
||||
_EMBEDDING_GROUP = group
|
||||
if rank in ranks:
|
||||
_EMBEDDING_GLOBAL_RANKS = embedding_ranks
|
||||
|
||||
group = torch.distributed.new_group(position_embedding_ranks)
|
||||
if rank in position_embedding_ranks:
|
||||
_POSITION_EMBEDDING_GROUP = group
|
||||
if rank in ranks:
|
||||
_POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks
|
||||
|
||||
# Initialize global memory buffer
|
||||
# This isn't really "parallel state" but there isn't another good place to
|
||||
# put this. If we end up with a more generic initialization of megatron-core
|
||||
# we could stick it there
|
||||
_set_global_memory_buffer()
|
||||
|
||||
|
||||
def initialize_all_reduce_launcher(
|
||||
max_num_tokens: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
disable_graph: bool = False,
|
||||
) -> None:
|
||||
global _ALL_REDUCE_LAUNCHER
|
||||
_ALL_REDUCE_LAUNCHER = GraphAllReduce(
|
||||
max_num_tokens=max_num_tokens,
|
||||
hidden_size=hidden_size,
|
||||
dtype=dtype,
|
||||
disable_graph=disable_graph,
|
||||
)
|
||||
|
||||
def model_parallel_is_initialized():
|
||||
"""Check if model and data parallel groups are initialized."""
|
||||
if _TENSOR_MODEL_PARALLEL_GROUP is None or \
|
||||
_PIPELINE_MODEL_PARALLEL_GROUP is None or \
|
||||
_DATA_PARALLEL_GROUP is None:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_model_parallel_group():
|
||||
"""Get the model parallel group the caller rank belongs to."""
|
||||
assert _MODEL_PARALLEL_GROUP is not None, \
|
||||
'model parallel group is not initialized'
|
||||
return _MODEL_PARALLEL_GROUP
|
||||
|
||||
|
||||
def get_tensor_model_parallel_group():
|
||||
"""Get the tensor model parallel group the caller rank belongs to."""
|
||||
assert _TENSOR_MODEL_PARALLEL_GROUP is not None, \
|
||||
'intra_layer_model parallel group is not initialized'
|
||||
return _TENSOR_MODEL_PARALLEL_GROUP
|
||||
|
||||
|
||||
def get_pipeline_model_parallel_group():
|
||||
"""Get the pipeline model parallel group the caller rank belongs to."""
|
||||
assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, \
|
||||
'pipeline_model parallel group is not initialized'
|
||||
return _PIPELINE_MODEL_PARALLEL_GROUP
|
||||
|
||||
|
||||
def get_data_parallel_group():
|
||||
"""Get the data parallel group the caller rank belongs to."""
|
||||
assert _DATA_PARALLEL_GROUP is not None, \
|
||||
'data parallel group is not initialized'
|
||||
return _DATA_PARALLEL_GROUP
|
||||
|
||||
|
||||
def get_embedding_group():
|
||||
"""Get the embedding group the caller rank belongs to."""
|
||||
assert _EMBEDDING_GROUP is not None, \
|
||||
'embedding group is not initialized'
|
||||
return _EMBEDDING_GROUP
|
||||
|
||||
|
||||
def get_position_embedding_group():
|
||||
"""Get the position embedding group the caller rank belongs to."""
|
||||
assert _POSITION_EMBEDDING_GROUP is not None, \
|
||||
'position embedding group is not initialized'
|
||||
return _POSITION_EMBEDDING_GROUP
|
||||
|
||||
|
||||
def set_tensor_model_parallel_world_size(world_size):
|
||||
"""Set the tensor model parallel size"""
|
||||
global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
|
||||
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size
|
||||
|
||||
|
||||
def set_pipeline_model_parallel_world_size(world_size):
|
||||
"""Set the pipeline model parallel size"""
|
||||
global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
|
||||
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size
|
||||
|
||||
|
||||
def get_tensor_model_parallel_world_size():
|
||||
"""Return world size for the tensor model parallel group."""
|
||||
global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
|
||||
if _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None:
|
||||
return _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
|
||||
return torch.distributed.get_world_size(group=get_tensor_model_parallel_group())
|
||||
|
||||
|
||||
def get_pipeline_model_parallel_world_size():
|
||||
"""Return world size for the pipeline model parallel group."""
|
||||
global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
|
||||
if _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None:
|
||||
return _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
|
||||
return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group())
|
||||
|
||||
|
||||
def set_tensor_model_parallel_rank(rank):
|
||||
"""Set tensor model parallel rank."""
|
||||
global _MPU_TENSOR_MODEL_PARALLEL_RANK
|
||||
_MPU_TENSOR_MODEL_PARALLEL_RANK = rank
|
||||
|
||||
|
||||
def set_pipeline_model_parallel_rank(rank):
|
||||
"""Set pipeline model parallel rank."""
|
||||
global _MPU_PIPELINE_MODEL_PARALLEL_RANK
|
||||
_MPU_PIPELINE_MODEL_PARALLEL_RANK = rank
|
||||
|
||||
|
||||
def set_pipeline_model_parallel_split_rank(rank):
|
||||
"""Set pipeline model parallel split rank."""
|
||||
global _MPU_PIPELINE_MODEL_PARALLEL_SPLIT_RANK
|
||||
_MPU_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = rank
|
||||
|
||||
|
||||
def get_tensor_model_parallel_rank():
|
||||
"""Return my rank for the tensor model parallel group."""
|
||||
global _MPU_TENSOR_MODEL_PARALLEL_RANK
|
||||
if _MPU_TENSOR_MODEL_PARALLEL_RANK is not None:
|
||||
return _MPU_TENSOR_MODEL_PARALLEL_RANK
|
||||
return torch.distributed.get_rank(group=get_tensor_model_parallel_group())
|
||||
|
||||
|
||||
def get_pipeline_model_parallel_rank():
|
||||
"""Return my rank for the pipeline model parallel group."""
|
||||
global _MPU_PIPELINE_MODEL_PARALLEL_RANK
|
||||
if _MPU_PIPELINE_MODEL_PARALLEL_RANK is not None:
|
||||
return _MPU_PIPELINE_MODEL_PARALLEL_RANK
|
||||
return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())
|
||||
|
||||
|
||||
|
||||
def is_pipeline_first_stage(ignore_virtual=False):
|
||||
"""Return True if in the first pipeline model-parallel stage, False otherwise."""
|
||||
if not ignore_virtual:
|
||||
if get_virtual_pipeline_model_parallel_world_size() is not None and \
|
||||
get_virtual_pipeline_model_parallel_rank() != 0:
|
||||
return False
|
||||
return get_pipeline_model_parallel_rank() == 0
|
||||
|
||||
|
||||
def is_pipeline_last_stage(ignore_virtual=False):
|
||||
"""Return True if in the last pipeline model-parallel stage, False otherwise."""
|
||||
if not ignore_virtual:
|
||||
virtual_pipeline_model_parallel_world_size = \
|
||||
get_virtual_pipeline_model_parallel_world_size()
|
||||
if virtual_pipeline_model_parallel_world_size is not None and \
|
||||
get_virtual_pipeline_model_parallel_rank() != (
|
||||
virtual_pipeline_model_parallel_world_size - 1):
|
||||
return False
|
||||
return get_pipeline_model_parallel_rank() == (
|
||||
get_pipeline_model_parallel_world_size() - 1)
|
||||
|
||||
|
||||
def is_rank_in_embedding_group(ignore_virtual=False):
|
||||
"""Return true if current rank is in embedding group, False otherwise."""
|
||||
rank = torch.distributed.get_rank()
|
||||
global _EMBEDDING_GLOBAL_RANKS
|
||||
if ignore_virtual:
|
||||
return rank in _EMBEDDING_GLOBAL_RANKS
|
||||
if rank in _EMBEDDING_GLOBAL_RANKS:
|
||||
if rank == _EMBEDDING_GLOBAL_RANKS[0]:
|
||||
return is_pipeline_first_stage(ignore_virtual=False)
|
||||
elif rank == _EMBEDDING_GLOBAL_RANKS[-1]:
|
||||
return is_pipeline_last_stage(ignore_virtual=False)
|
||||
else:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_rank_in_position_embedding_group():
|
||||
"""Return true if current rank is in position embedding group, False otherwise."""
|
||||
rank = torch.distributed.get_rank()
|
||||
global _POSITION_EMBEDDING_GLOBAL_RANKS
|
||||
return rank in _POSITION_EMBEDDING_GLOBAL_RANKS
|
||||
|
||||
|
||||
def is_pipeline_stage_before_split(rank=None):
|
||||
"""Return True if pipeline stage executes encoder block for a model
|
||||
with both encoder and decoder."""
|
||||
if get_pipeline_model_parallel_world_size() == 1:
|
||||
return True
|
||||
if rank is None:
|
||||
rank = get_pipeline_model_parallel_rank()
|
||||
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
|
||||
if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
|
||||
return True
|
||||
if rank < _PIPELINE_MODEL_PARALLEL_SPLIT_RANK:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_pipeline_stage_after_split(rank=None):
|
||||
"""Return True if pipeline stage executes decoder block for a model
|
||||
with both encoder and decoder."""
|
||||
if get_pipeline_model_parallel_world_size() == 1:
|
||||
return True
|
||||
if rank is None:
|
||||
rank = get_pipeline_model_parallel_rank()
|
||||
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
|
||||
if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
|
||||
return True
|
||||
if rank >= _PIPELINE_MODEL_PARALLEL_SPLIT_RANK:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_pipeline_stage_at_split():
|
||||
"""Return true if pipeline stage executes decoder block and next
|
||||
stage executes encoder block for a model with both encoder and
|
||||
decoder."""
|
||||
rank = get_pipeline_model_parallel_rank()
|
||||
return is_pipeline_stage_before_split(rank) and \
|
||||
is_pipeline_stage_after_split(rank+1)
|
||||
|
||||
|
||||
def get_virtual_pipeline_model_parallel_rank():
|
||||
"""Return the virtual pipeline-parallel rank."""
|
||||
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
|
||||
return _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
|
||||
|
||||
|
||||
def set_virtual_pipeline_model_parallel_rank(rank):
|
||||
"""Set the virtual pipeline-parallel rank."""
|
||||
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
|
||||
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank
|
||||
|
||||
|
||||
def get_virtual_pipeline_model_parallel_world_size():
|
||||
"""Return the virtual pipeline-parallel world size."""
|
||||
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
|
||||
return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
|
||||
|
||||
|
||||
def get_tensor_model_parallel_src_rank():
|
||||
"""Calculate the global rank corresponding to the first local rank
|
||||
in the tensor model parallel group."""
|
||||
global_rank = torch.distributed.get_rank()
|
||||
local_world_size = get_tensor_model_parallel_world_size()
|
||||
return (global_rank // local_world_size) * local_world_size
|
||||
|
||||
|
||||
def get_data_parallel_src_rank():
|
||||
"""Calculate the global rank corresponding to the first local rank
|
||||
in the data parallel group."""
|
||||
assert _DATA_PARALLEL_GLOBAL_RANKS is not None, \
|
||||
"Data parallel group is not initialized"
|
||||
return _DATA_PARALLEL_GLOBAL_RANKS[0]
|
||||
|
||||
|
||||
def get_pipeline_model_parallel_first_rank():
|
||||
"""Return the global rank of the first process in the pipeline for the
|
||||
current tensor parallel group"""
|
||||
assert _PIPELINE_GLOBAL_RANKS is not None, \
|
||||
"Pipeline parallel group is not initialized"
|
||||
return _PIPELINE_GLOBAL_RANKS[0]
|
||||
|
||||
|
||||
def get_pipeline_model_parallel_last_rank():
|
||||
"""Return the global rank of the last process in the pipeline for the
|
||||
current tensor parallel group"""
|
||||
assert _PIPELINE_GLOBAL_RANKS is not None, \
|
||||
"Pipeline parallel group is not initialized"
|
||||
last_rank_local = get_pipeline_model_parallel_world_size() - 1
|
||||
return _PIPELINE_GLOBAL_RANKS[last_rank_local]
|
||||
|
||||
def get_pipeline_model_parallel_next_rank():
|
||||
"""Return the global rank that follows the caller in the pipeline"""
|
||||
assert _PIPELINE_GLOBAL_RANKS is not None, \
|
||||
"Pipeline parallel group is not initialized"
|
||||
rank_in_pipeline = get_pipeline_model_parallel_rank()
|
||||
world_size = get_pipeline_model_parallel_world_size()
|
||||
return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]
|
||||
|
||||
|
||||
def get_pipeline_model_parallel_prev_rank():
|
||||
"""Return the global rank that preceeds the caller in the pipeline"""
|
||||
assert _PIPELINE_GLOBAL_RANKS is not None, \
|
||||
"Pipeline parallel group is not initialized"
|
||||
rank_in_pipeline = get_pipeline_model_parallel_rank()
|
||||
world_size = get_pipeline_model_parallel_world_size()
|
||||
return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]
|
||||
|
||||
|
||||
def get_data_parallel_world_size():
|
||||
"""Return world size for the data parallel group."""
|
||||
return torch.distributed.get_world_size(group=get_data_parallel_group())
|
||||
|
||||
|
||||
def get_data_parallel_rank():
|
||||
"""Return my rank for the data parallel group."""
|
||||
return torch.distributed.get_rank(group=get_data_parallel_group())
|
||||
|
||||
def _set_global_memory_buffer():
|
||||
"""Initialize global buffer"""
|
||||
global _GLOBAL_MEMORY_BUFFER
|
||||
assert _GLOBAL_MEMORY_BUFFER is None, 'global memory buffer is already initialized'
|
||||
_GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer()
|
||||
|
||||
def get_global_memory_buffer():
|
||||
"""Return the global GlobalMemoryBuffer object"""
|
||||
assert _GLOBAL_MEMORY_BUFFER is not None, 'global memory buffer is not initialized'
|
||||
return _GLOBAL_MEMORY_BUFFER
|
||||
|
||||
def get_all_reduce_launcher() -> 'GraphAllReduce':
|
||||
assert _ALL_REDUCE_LAUNCHER is not None, 'all reduce launcher is not initialized'
|
||||
return _ALL_REDUCE_LAUNCHER
|
||||
|
||||
def destroy_model_parallel():
|
||||
"""Set the groups to none."""
|
||||
global _MODEL_PARALLEL_GROUP
|
||||
_MODEL_PARALLEL_GROUP = None
|
||||
global _TENSOR_MODEL_PARALLEL_GROUP
|
||||
_TENSOR_MODEL_PARALLEL_GROUP = None
|
||||
global _PIPELINE_MODEL_PARALLEL_GROUP
|
||||
_PIPELINE_MODEL_PARALLEL_GROUP = None
|
||||
global _DATA_PARALLEL_GROUP
|
||||
_DATA_PARALLEL_GROUP = None
|
||||
global _EMBEDDING_GROUP
|
||||
_EMBEDDING_GROUP = None
|
||||
global _POSITION_EMBEDDING_GROUP
|
||||
_POSITION_EMBEDDING_GROUP = None
|
||||
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
|
||||
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
|
||||
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
|
||||
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
|
||||
global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
|
||||
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
|
||||
global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
|
||||
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
|
||||
global _MPU_TENSOR_MODEL_PARALLEL_RANK
|
||||
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
|
||||
global _MPU_PIPELINE_MODEL_PARALLEL_RANK
|
||||
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None
|
||||
global _GLOBAL_MEMORY_BUFFER
|
||||
_GLOBAL_MEMORY_BUFFER = None
|
||||
|
||||
|
||||
class GraphAllReduce:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_num_tokens: int,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
disable_graph: bool = False,
|
||||
) -> None:
|
||||
self.max_num_tokens = max_num_tokens
|
||||
self.hidden_size = hidden_size
|
||||
self.disable_graph = disable_graph
|
||||
|
||||
tp_world_size = get_tensor_model_parallel_world_size()
|
||||
if tp_world_size == 1:
|
||||
return
|
||||
|
||||
self.group = get_tensor_model_parallel_group()
|
||||
self.buffer = torch.empty(
|
||||
size=(max_num_tokens, hidden_size),
|
||||
dtype=dtype,
|
||||
device='cuda',
|
||||
)
|
||||
|
||||
# Build graphs for different number of tokens.
|
||||
if not self.disable_graph:
|
||||
self.graphs = {}
|
||||
for num_tokens in range(8, max_num_tokens + 1, 8):
|
||||
self.graphs[num_tokens] = self._build_graph(num_tokens)
|
||||
|
||||
def _build_graph(self, num_tokens: int) -> torch.cuda.CUDAGraph:
|
||||
# Warm up.
|
||||
torch.distributed.all_reduce(self.buffer[:num_tokens], group=self.group)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
# Build graph.
|
||||
graph = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(graph):
|
||||
torch.distributed.all_reduce(
|
||||
self.buffer[:num_tokens], group=self.group)
|
||||
torch.cuda.synchronize()
|
||||
return graph
|
||||
|
||||
def launch(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# NOTE: x must be a slice of self.buffer.
|
||||
num_tokens = x.shape[0]
|
||||
if self.disable_graph:
|
||||
torch.distributed.all_reduce(x, group=self.group)
|
||||
else:
|
||||
self.graphs[num_tokens].replay()
|
||||
return x
|
||||
Reference in New Issue
Block a user