[Model] PP support for embedding models and update docs (#9090)

Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
This commit is contained in:
Cyrus Leung
2024-10-06 16:35:27 +08:00
committed by GitHub
parent f22619fe96
commit b22b798471
12 changed files with 612 additions and 451 deletions

View File

@@ -306,10 +306,12 @@ def get_pp_missing_layer_names(model: torch.nn.Module) -> List[str]:
def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
"""Check if a parameter is missing in a pipeline parallel model."""
for missing_layer_name in get_pp_missing_layer_names(model):
if name.startswith(missing_layer_name):
return True
return False
if isinstance(model, PPMissingLayer):
return True
return any(
name.startswith(missing_layer_name)
for missing_layer_name in get_pp_missing_layer_names(model))
def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int):