[Model] Pipeline parallel support for Mixtral (#6516)

Author: Cody Yu
Date: 2024-07-17 19:26:04 -07:00
Committed by: GitHub
Commit: b5af8c223c (parent: b5241e41d9)

3 changed files with 60 additions and 19 deletions
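What the change enables, in brief: Mixtral checkpoints can now be sharded across pipeline stages in addition to tensor-parallel ranks. A minimal sketch in the same arg-list style the test below uses (the model name and sizes here are illustrative, not taken from this diff):

    # Hypothetical example; --pipeline-parallel-size is vLLM's standard flag.
    mixtral_pp_args = [
        "--dtype", "bfloat16",
        "--pipeline-parallel-size", "2",
        "--tensor-parallel-size", "1",
    ]
    # e.g. with RemoteOpenAIServer("mistralai/Mixtral-8x7B-Instruct-v0.1",
    #                              mixtral_pp_args) as server: ...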


@@ -1,4 +1,5 @@
 import pytest
+from transformers import AutoTokenizer

 from ..utils import RemoteOpenAIServer

@@ -12,6 +13,8 @@ from ..utils import RemoteOpenAIServer
     (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B"),
 ])
 def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+
     pp_args = [
         # use half precision for speed and memory savings in CI environment
         "--dtype",
@@ -34,7 +37,7 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
         "--dtype",
         "bfloat16",
         "--tensor-parallel-size",
-        str(max(TP_SIZE, 2)),  # use at least TP_SIZE=2 to hold the model
+        str(max(TP_SIZE, 2)),  # We only use 2 GPUs in the CI.
         "--distributed-executor-backend",
         "mp",
     ]
@@ -45,8 +48,10 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
         pp_args.append("--enforce-eager")
         tp_args.append("--enforce-eager")

+    prompt = "Hello, my name is"
+    token_ids = tokenizer(prompt)["input_ids"]
     results = []
-    for args in [pp_args, tp_args]:
+    for args in (pp_args, tp_args):
         with RemoteOpenAIServer(MODEL_NAME, args) as server:
             client = server.get_client()
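The assertions are truncated out of this diff, but the pattern above amounts to: start one server per arg set, run identical greedy requests, and compare outputs. A hedged sketch of that flow (the helper structure here is ours, not the test's exact code):

    # Greedy decoding (temperature=0.0) is what makes the pipeline-parallel
    # and tensor-parallel outputs directly comparable.
    outputs = []
    for args in (pp_args, tp_args):
        with RemoteOpenAIServer(MODEL_NAME, args) as server:
            client = server.get_client()
            completion = client.completions.create(model=MODEL_NAME,
                                                   prompt=prompt,
                                                   max_tokens=5,
                                                   temperature=0.0)
            outputs.append(completion.choices[0].text)
    assert outputs[0] == outputs[1]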
@@ -62,7 +67,7 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
             # test with text prompt
             completion = client.completions.create(model=MODEL_NAME,
-                                                   prompt="Hello, my name is",
+                                                   prompt=prompt,
                                                    max_tokens=5,
                                                    temperature=0.0)

@@ -76,7 +81,7 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
             # test using token IDs
             completion = client.completions.create(
                 model=MODEL_NAME,
-                prompt=[0, 0, 0, 0, 0],
+                prompt=token_ids,
                 max_tokens=5,
                 temperature=0.0,
             )
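Why swap the placeholder [0, 0, 0, 0, 0] for token_ids: encoding the same text prompt makes the token-ID request semantically identical to the text request, so both parallelism modes should produce matching continuations. For example:

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
    ids = tok("Hello, my name is")["input_ids"]
    # ids typically includes the BOS token, so prompt=ids round-trips the
    # text prompt rather than feeding five meaningless zeros.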
@@ -91,7 +96,7 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
             # test simple list
             batch = client.completions.create(
                 model=MODEL_NAME,
-                prompt=["Hello, my name is", "Hello, my name is"],
+                prompt=[prompt, prompt],
                 max_tokens=5,
                 temperature=0.0,
             )
@@ -105,7 +110,7 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME):
             # test streaming
             batch = client.completions.create(
                 model=MODEL_NAME,
-                prompt=["Hello, my name is", "Hello, my name is"],
+                prompt=[prompt, prompt],
                 max_tokens=5,
                 temperature=0.0,
                 stream=True,
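The diff cuts off inside the streaming request. With the OpenAI Python client, stream=True makes completions.create return an iterator of chunks; a typical way to reassemble and check the two streams (a sketch, not the test's exact code):

    # Each chunk carries one choice; choice.index says which of the two
    # batched prompts the text fragment belongs to.
    texts = [""] * 2
    for chunk in batch:
        assert len(chunk.choices) == 1
        choice = chunk.choices[0]
        texts[choice.index] += choice.text
    # Both streamed completions should match the non-streamed batch results.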