[core][distributed] simplify code to support pipeline parallel (#6406)

commit 69672f116c (parent 44874a0bf9)
Author: youkaichao
Date: 2024-07-14 21:20:51 -07:00
Committed by: GitHub
5 changed files with 107 additions and 61 deletions


@@ -1,3 +1,5 @@
from typing import Callable, Dict, List, Tuple

import torch

from vllm.multimodal import BatchedTensors
@@ -39,3 +41,57 @@ def merge_vision_embeddings(input_ids: torch.Tensor,
    inputs_embeds[mask] = torch.cat(vision_embeddings)
    return inputs_embeds


class PPMissingLayer(torch.nn.Identity):
    """
    A placeholder layer for missing layers in a pipeline parallel model.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
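A quick sanity check (hypothetical usage, not part of this diff) shows why PPMissingLayer is safe to leave in a module list: it accepts and discards any constructor arguments, and its forward pass returns its input unchanged.

import torch

layer = PPMissingLayer(hidden_size=4096)  # ctor args are accepted and ignored
x = torch.randn(2, 8)
assert torch.equal(layer(x), x)  # identity: activations pass through unchanged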
def make_layers(
    num_hidden_layers: int, layer_fn: Callable[[], torch.nn.Module]
) -> Tuple[int, int, torch.nn.ModuleList]:
    """Make a list of layers with the given layer function, taking
    pipeline parallelism into account.
    """
    from vllm.distributed.parallel_state import get_pp_group
    from vllm.distributed.utils import get_pp_indices
    start_layer, end_layer = get_pp_indices(num_hidden_layers,
                                            get_pp_group().rank_in_group,
                                            get_pp_group().world_size)
    modules = torch.nn.ModuleList(
        [PPMissingLayer() for _ in range(start_layer)] +
        [layer_fn() for _ in range(start_layer, end_layer)] +
        [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
    return start_layer, end_layer, modules
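As a hedged sketch of how a model consumes make_layers (MyModel, MyDecoderLayer, and config are stand-ins, not names from this commit): the ModuleList stays full length so parameter names like "layers.12..." match the checkpoint on every rank, while each pipeline rank only runs the slice it owns.

import torch

class MyModel(torch.nn.Module):

    def __init__(self, config):
        super().__init__()
        # Full-length list: positions outside [start_layer, end_layer)
        # hold PPMissingLayer placeholders with no parameters.
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers, lambda: MyDecoderLayer(config))

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Run only the layers owned by this pipeline rank.
        for i in range(self.start_layer, self.end_layer):
            hidden_states = self.layers[i](hidden_states)
        return hidden_states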
# NOTE: don't use lru_cache here because it can prevent garbage collection
_model_to_pp_missing_layer_names: Dict[int, List[str]] = {}


def get_pp_missing_layer_names(model: torch.nn.Module) -> List[str]:
    """Get the names of the missing layers in a pipeline parallel model."""
    model_id = id(model)
    if model_id in _model_to_pp_missing_layer_names:
        return _model_to_pp_missing_layer_names[model_id]
    missing_layer_names = []
    for name, module in model.named_modules():
        if isinstance(module, PPMissingLayer):
            missing_layer_names.append(name)
    _model_to_pp_missing_layer_names[model_id] = missing_layer_names
    return missing_layer_names


def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
    """Check if a parameter is missing in a pipeline parallel model."""
    for missing_layer_name in get_pp_missing_layer_names(model):
        if name.startswith(missing_layer_name):
            return True
    return False
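Putting the pieces together, a weight loader can use is_pp_missing_parameter to skip tensors owned by other pipeline stages; without the skip, the lookup in named_parameters() would fail, since PPMissingLayer contributes no parameters on this rank. This is a minimal sketch under assumed inputs (model, weights), not the loader from this commit.

from typing import Iterable, Tuple

import torch

def load_weights(model: torch.nn.Module,
                 weights: Iterable[Tuple[str, torch.Tensor]]) -> None:
    params_dict = dict(model.named_parameters())
    for name, loaded_weight in weights:
        # Parameters under a PPMissingLayer live on another pipeline rank,
        # so they are absent from params_dict here.
        if is_pp_missing_parameter(name, model):
            continue
        params_dict[name].data.copy_(loaded_weight)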