[core][distributed] simplify code to support pipeline parallel (#6406)
@@ -1,3 +1,5 @@
+from typing import Callable, Dict, List, Tuple
+
 import torch
 
 from vllm.multimodal import BatchedTensors
@@ -39,3 +41,57 @@ def merge_vision_embeddings(input_ids: torch.Tensor,
     inputs_embeds[mask] = torch.cat(vision_embeddings)
 
     return inputs_embeds
+
+
+class PPMissingLayer(torch.nn.Identity):
+    """
+    A placeholder layer for missing layers in a pipeline parallel model.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__()
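PPMissingLayer subclasses torch.nn.Identity, so it forwards its input unchanged, and its constructor swallows whatever arguments the real layer's constructor would receive. That lets the same construction code run on every pipeline rank: ranks that do not own a layer get a weightless pass-through at the same index. A minimal sketch of these semantics (the constructor argument and tensor shapes below are illustrative, not from the commit):

    import torch

    layer = PPMissingLayer(hidden_size=4096)  # constructor args are ignored
    x = torch.randn(2, 8)
    assert layer(x) is x  # torch.nn.Identity returns its input as-is
    assert sum(p.numel() for p in layer.parameters()) == 0  # holds no weights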
+def make_layers(
+    num_hidden_layers: int, layer_fn: Callable[[], torch.nn.Module]
+) -> Tuple[int, int, torch.nn.ModuleList]:
+    """Make a list of layers with the given layer function, taking
+    pipeline parallelism into account.
+    """
+    from vllm.distributed.parallel_state import get_pp_group
+    from vllm.distributed.utils import get_pp_indices
+    start_layer, end_layer = get_pp_indices(num_hidden_layers,
+                                            get_pp_group().rank_in_group,
+                                            get_pp_group().world_size)
+    modules = torch.nn.ModuleList(
+        [PPMissingLayer() for _ in range(start_layer)] +
+        [layer_fn() for _ in range(start_layer, end_layer)] +
+        [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
+    return start_layer, end_layer, modules
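make_layers hands every pipeline rank the full num_hidden_layers-long ModuleList, but only materializes the slice [start_layer, end_layer) that get_pp_indices assigns to this rank; everything outside the slice is a PPMissingLayer. Keeping the global indices means parameter names still line up with the checkpoint on every rank. A hypothetical decoder showing how a model might consume the returned triple (the class and its forward loop are illustrative, not part of this commit):

    import torch
    from typing import Callable

    class ToyDecoder(torch.nn.Module):
        def __init__(self, num_hidden_layers: int,
                     layer_fn: Callable[[], torch.nn.Module]):
            super().__init__()
            # Real layers only for this rank's slice; placeholders elsewhere
            # keep the layer numbering (and checkpoint names) global.
            self.start_layer, self.end_layer, self.layers = make_layers(
                num_hidden_layers, layer_fn)

        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            # Only the locally owned layers execute; placeholders are
            # never called.
            for layer in self.layers[self.start_layer:self.end_layer]:
                hidden_states = layer(hidden_states)
            return hidden_states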
+# NOTE: don't use lru_cache here because it can prevent garbage collection
+_model_to_pp_missing_layer_names: Dict[int, List[str]] = {}
+
+
+def get_pp_missing_layer_names(model: torch.nn.Module) -> List[str]:
+    """Get the names of the missing layers in a pipeline parallel model."""
+    model_id = id(model)
+    if model_id in _model_to_pp_missing_layer_names:
+        return _model_to_pp_missing_layer_names[model_id]
+
+    missing_layer_names = []
+    for name, module in model.named_modules():
+        if isinstance(module, PPMissingLayer):
+            missing_layer_names.append(name)
+    _model_to_pp_missing_layer_names[model_id] = missing_layer_names
+
+    return missing_layer_names
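The cache is keyed by id(model) and stores only strings, so it never holds a reference to the model itself; functools.lru_cache, by contrast, would keep the model alive as a cached argument, which is what the NOTE warns about. (The trade-off is that entries for freed models are never evicted, but they cost only a few strings.) A quick sketch of the memoization using a throwaway module tree (names illustrative):

    import torch

    toy = torch.nn.Sequential(
        PPMissingLayer(),       # stands in for a layer owned by another rank
        torch.nn.Linear(8, 8),  # a layer this rank actually owns
        PPMissingLayer(),
    )
    # First call walks named_modules(); under nn.Sequential naming this
    # prints ['0', '2'].
    print(get_pp_missing_layer_names(toy))
    # Later calls are a dict lookup and return the same cached list object.
    assert get_pp_missing_layer_names(toy) is get_pp_missing_layer_names(toy)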
+def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
+    """Check if a parameter is missing in a pipeline parallel model."""
+    for missing_layer_name in get_pp_missing_layer_names(model):
+        if name.startswith(missing_layer_name):
+            return True
+    return False
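is_pp_missing_parameter is aimed at weight-loading loops: a parameter name such as "layers.3.self_attn.qkv_proj.weight" begins with its owning module's name ("layers.3"), so a prefix match identifies weights that belong to another rank. One caveat worth noting: a bare prefix match lets "layers.1" also match "layers.10..."; checking name.startswith(missing_layer_name + ".") would rule that out. A hedged sketch of how a loader might consult the helper (the loop shape and names are illustrative, not from this commit):

    import torch
    from typing import Iterable, Tuple

    def load_weights(model: torch.nn.Module,
                     weights: Iterable[Tuple[str, torch.Tensor]]) -> None:
        params = dict(model.named_parameters())
        for name, loaded_weight in weights:
            if is_pp_missing_parameter(name, model):
                # The layer lives on another pipeline rank; skip it here.
                continue
            params[name].data.copy_(loaded_weight)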