[Core] Pipeline Parallel Support (#4412)
Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
This commit is contained in:
committed by
GitHub
parent
15aba081f3
commit
c5832d2ae9
@@ -27,6 +27,17 @@ logger = init_logger(__name__)
|
||||
_GB = 1 << 30
|
||||
_EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768
|
||||
|
||||
_PP_SUPPORTED_MODELS = [
|
||||
"AquilaModel",
|
||||
"AquilaForCausalLM",
|
||||
"InternLMForCausalLM",
|
||||
"LlamaForCausalLM",
|
||||
"LLaMAForCausalLM",
|
||||
"MistralForCausalLM",
|
||||
"Phi3ForCausalLM",
|
||||
"GPT2LMHeadModel",
|
||||
]
|
||||
|
||||
|
||||
class ModelConfig:
|
||||
"""Configuration for the model.
|
||||
@@ -258,6 +269,13 @@ class ModelConfig:
|
||||
total_num_hidden_layers = getattr(self.hf_text_config,
|
||||
"num_hidden_layers", 0)
|
||||
pipeline_parallel_size = parallel_config.pipeline_parallel_size
|
||||
architectures = getattr(self.hf_config, "architectures", [])
|
||||
if not all(arch in _PP_SUPPORTED_MODELS
|
||||
for arch in architectures) and pipeline_parallel_size > 1:
|
||||
raise NotImplementedError(
|
||||
"Pipeline parallelism is only supported for the following "
|
||||
f" architectures: {_PP_SUPPORTED_MODELS}.")
|
||||
|
||||
if total_num_hidden_layers % pipeline_parallel_size != 0:
|
||||
raise ValueError(
|
||||
f"Total number of hidden layers ({total_num_hidden_layers}) "
|
||||
@@ -665,9 +683,10 @@ class ParallelConfig:
|
||||
self._verify_args()
|
||||
|
||||
def _verify_args(self) -> None:
|
||||
if self.pipeline_parallel_size > 1:
|
||||
raise NotImplementedError(
|
||||
"Pipeline parallelism is not supported yet.")
|
||||
if (self.pipeline_parallel_size > 1
|
||||
and self.distributed_executor_backend == "mp"):
|
||||
raise NotImplementedError("Pipeline parallelism is not supported "
|
||||
"yet with multiprocessing.")
|
||||
if self.distributed_executor_backend not in ("ray", "mp", None):
|
||||
raise ValueError(
|
||||
"Unrecognized distributed executor backend. Supported values "
|
||||
|
||||
Reference in New Issue
Block a user