[ci][distributed] try to fix pp test (#7054)
This commit is contained in:
@@ -9,7 +9,7 @@ import os
|
||||
|
||||
import pytest
|
||||
|
||||
-from ..utils import compare_two_settings
|
||||
+from ..utils import compare_two_settings, fork_new_process_for_each_test
|
||||
|
||||
VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
|
||||
|
||||
@@ -28,6 +28,7 @@ VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
|
||||
(1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"),
|
||||
(1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
|
||||
])
|
||||
+@fork_new_process_for_each_test
|
||||
def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
|
||||
DIST_BACKEND):
|
||||
if VLLM_MULTI_NODE and DIST_BACKEND == "mp":
|
||||
@@ -77,6 +78,7 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
|
||||
"FLASH_ATTN",
|
||||
"FLASHINFER",
|
||||
])
|
||||
+@fork_new_process_for_each_test
|
||||
def test_pp_cudagraph(PP_SIZE, MODEL_NAME, ATTN_BACKEND):
|
||||
cudagraph_args = [
|
||||
# use half precision for speed and memory savings in CI environment
|
||||
|
||||
Reference in New Issue
Block a user