[core][distributed] support n layers % pp size != 0 (#6115)

This commit is contained in:
youkaichao
2024-07-03 16:40:31 -07:00
committed by GitHub
parent 966fe72141
commit 3de6e6a30e
7 changed files with 19 additions and 10 deletions

View File

@@ -265,8 +265,6 @@ class ModelConfig:
" must be divisible by tensor parallel size "
f"({tensor_parallel_size}).")
total_num_hidden_layers = getattr(self.hf_text_config,
"num_hidden_layers", 0)
pipeline_parallel_size = parallel_config.pipeline_parallel_size
architectures = getattr(self.hf_config, "architectures", [])
if not all(arch in _PP_SUPPORTED_MODELS
@@ -275,12 +273,6 @@ class ModelConfig:
"Pipeline parallelism is only supported for the following "
f" architectures: {_PP_SUPPORTED_MODELS}.")
if total_num_hidden_layers % pipeline_parallel_size != 0:
raise ValueError(
f"Total number of hidden layers ({total_num_hidden_layers}) "
"must be divisible by pipeline parallel size "
f"({pipeline_parallel_size}).")
if self.quantization == "bitsandbytes" and (
parallel_config.tensor_parallel_size > 1
or parallel_config.pipeline_parallel_size > 1):
@@ -385,9 +377,13 @@ class ModelConfig:
return num_heads // parallel_config.tensor_parallel_size
def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
from vllm.distributed.utils import get_pp_indices
total_num_hidden_layers = getattr(self.hf_text_config,
"num_hidden_layers", 0)
return total_num_hidden_layers // parallel_config.pipeline_parallel_size
pp_rank = parallel_config.rank // parallel_config.tensor_parallel_size
pp_size = parallel_config.pipeline_parallel_size
start, end = get_pp_indices(total_num_hidden_layers, pp_rank, pp_size)
return end - start
def contains_seqlen_agnostic_layers(
self, parallel_config: "ParallelConfig") -> bool:
@@ -709,6 +705,7 @@ class ParallelConfig:
{"CUDA_VISIBLE_DEVICES": envs.CUDA_VISIBLE_DEVICES})
self._verify_args()
self.rank = 0
def _verify_args(self) -> None:
if (self.pipeline_parallel_size > 1