[Model] Support pp for qwen2-vl (#8696)
This commit is contained in:
@@ -8,6 +8,8 @@ WARNING: This test runs in both single-node (4 GPUs) and multi-node
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from packaging import version
|
||||
from transformers import __version__ as transformers_version
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
@@ -37,6 +39,7 @@ VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
|
||||
(1, 2, 1, 1, 1, "OpenGVLab/InternVL2-1B", "mp"),
|
||||
(1, 2, 1, 1, 1, "OpenGVLab/InternVL2-2B", "mp"),
|
||||
(1, 2, 1, 0, 1, "OpenGVLab/InternVL2-4B", "mp"),
|
||||
(1, 2, 0, 1, 0, "Qwen/Qwen2-VL-2B-Instruct", "mp")
|
||||
],
|
||||
)
|
||||
@fork_new_process_for_each_test
|
||||
@@ -46,6 +49,11 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL,
|
||||
pytest.skip("Skipping multi-node pipeline parallel test for "
|
||||
"multiprocessing distributed backend")
|
||||
|
||||
# Skip tests that require transformers>=4.45.0
|
||||
if "Qwen2-VL" in MODEL_NAME and version.parse(
|
||||
transformers_version) < version.parse("4.45.0.dev0"):
|
||||
pytest.skip("This test requires transformers>=4.45.0")
|
||||
|
||||
pp_args = [
|
||||
# use half precision for speed and memory savings in CI environment
|
||||
"--dtype",
|
||||
|
||||
Reference in New Issue
Block a user