[ci][distributed] try to fix pp test (#7054)

This commit is contained in:
youkaichao
2024-08-01 22:03:12 -07:00
committed by GitHub
parent 3bb4b1e4cd
commit 252357793d
4 changed files with 45 additions and 3 deletions

View File

@@ -9,7 +9,7 @@ import os
import pytest
from ..utils import compare_two_settings
from ..utils import compare_two_settings, fork_new_process_for_each_test
VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
@@ -28,6 +28,7 @@ VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
(1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"),
(1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
])
@fork_new_process_for_each_test
def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
DIST_BACKEND):
if VLLM_MULTI_NODE and DIST_BACKEND == "mp":
@@ -77,6 +78,7 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
"FLASH_ATTN",
"FLASHINFER",
])
@fork_new_process_for_each_test
def test_pp_cudagraph(PP_SIZE, MODEL_NAME, ATTN_BACKEND):
cudagraph_args = [
# use half precision for speed and memory savings in CI environment