[Core] Multiprocessing Pipeline Parallel support (#6130)

commit b5672a112c (parent c5df56f88b)
Author: Nick Hill
Co-authored-by: Murali Andoorveedu <muralidhar.andoorveedu@centml.ai>
Committed by: GitHub
Date: 2024-07-18 19:15:52 -07:00

9 changed files with 152 additions and 99 deletions
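In practical terms, this commit lets pipeline parallelism run under the multiprocessing ("mp") executor backend instead of requiring Ray. A minimal usage sketch, not part of this diff: it assumes the standard vllm.LLM constructor forwards pipeline_parallel_size and distributed_executor_backend through to the engine, mirroring the CLI flags exercised in the test below.

from vllm import LLM, SamplingParams

# Hedged sketch: pipeline parallelism combined with the "mp"
# (multiprocessing) backend newly supported by this commit. The keyword
# names are assumed to mirror the CLI flags --pipeline-parallel-size and
# --distributed-executor-backend used in the test diff below.
llm = LLM(
    model="meta-llama/Meta-Llama-3-8B",
    pipeline_parallel_size=2,           # assumed engine kwarg
    tensor_parallel_size=2,
    distributed_executor_backend="mp",  # assumed engine kwarg
    dtype="float16",
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)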


@@ -1,28 +1,42 @@
+import os
+
 import pytest
 
 from ..utils import compare_two_settings
 
+VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
+
 
 @pytest.mark.parametrize(
-    "TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME", [
-        (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B"),
-        (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B"),
-        (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B"),
-        (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B"),
-        (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B"),
+    "TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME, DIST_BACKEND",
+    [
+        (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"),
+        (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
+        (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
+        (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"),
+        (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
+        (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"),
+        (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
+        (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
+        (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"),
+        (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
     ])
-def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
+def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
+                    DIST_BACKEND):
+    if VLLM_MULTI_NODE and DIST_BACKEND == "mp":
+        pytest.skip("Skipping multi-node pipeline parallel test for "
+                    "multiprocessing distributed backend")
+
     pp_args = [
         # use half precision for speed and memory savings in CI environment
         "--dtype",
-        "bfloat16",
+        "float16",
         "--pipeline-parallel-size",
         str(PP_SIZE),
         "--tensor-parallel-size",
         str(TP_SIZE),
         "--distributed-executor-backend",
-        "ray",
+        DIST_BACKEND,
     ]
 
     # compare without pipeline parallelism
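For local verification, a hedged sketch of driving the parametrized test with pytest from Python: the file path is assumed (the truncated diff view above does not show it), and VLLM_MULTI_NODE is left unset so the new "mp" cases are not skipped by the guard added in this commit.

import os

import pytest

# Assumed location of the test shown in the diff above; not confirmed
# by the truncated view.
TEST_FILE = "tests/distributed/test_pipeline_parallel.py"

# Ensure the single-node default ("0") so the DIST_BACKEND == "mp" cases
# run rather than hitting the multi-node skip guard.
os.environ.pop("VLLM_MULTI_NODE", None)

raise SystemExit(pytest.main(["-v", TEST_FILE]))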