[Model] PP support for embedding models and update docs (#9090)
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
This commit is contained in:
@@ -306,10 +306,12 @@ def get_pp_missing_layer_names(model: torch.nn.Module) -> List[str]:
|
||||
|
||||
def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
|
||||
"""Check if a parameter is missing in a pipeline parallel model."""
|
||||
for missing_layer_name in get_pp_missing_layer_names(model):
|
||||
if name.startswith(missing_layer_name):
|
||||
return True
|
||||
return False
|
||||
if isinstance(model, PPMissingLayer):
|
||||
return True
|
||||
|
||||
return any(
|
||||
name.startswith(missing_layer_name)
|
||||
for missing_layer_name in get_pp_missing_layer_names(model))
|
||||
|
||||
|
||||
def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int):
|
||||
|
||||
Reference in New Issue
Block a user