[Model] Support pp for qwen2-vl (#8696)

This commit is contained in:
Yanyi Liu
2024-09-23 21:46:59 +08:00
committed by GitHub
parent 3e83c12b5c
commit a79e522984
4 changed files with 46 additions and 14 deletions

View File

@@ -8,6 +8,8 @@ WARNING: This test runs in both single-node (4 GPUs) and multi-node
import os
import pytest
from packaging import version
from transformers import __version__ as transformers_version
from vllm.logger import init_logger
@@ -37,6 +39,7 @@ VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
(1, 2, 1, 1, 1, "OpenGVLab/InternVL2-1B", "mp"),
(1, 2, 1, 1, 1, "OpenGVLab/InternVL2-2B", "mp"),
(1, 2, 1, 0, 1, "OpenGVLab/InternVL2-4B", "mp"),
(1, 2, 0, 1, 0, "Qwen/Qwen2-VL-2B-Instruct", "mp")
],
)
@fork_new_process_for_each_test
@@ -46,6 +49,11 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL,
pytest.skip("Skipping multi-node pipeline parallel test for "
"multiprocessing distributed backend")
# Skip tests that require transformers>=4.45.0
if "Qwen2-VL" in MODEL_NAME and version.parse(
transformers_version) < version.parse("4.45.0.dev0"):
pytest.skip("This test requires transformers>=4.45.0")
pp_args = [
# use half precision for speed and memory savings in CI environment
"--dtype",