Signed-off-by: Lucas Kabela <lucaskabela@meta.com>
This commit is contained in:
@@ -20,7 +20,6 @@ class TestSetting:
|
||||
tp_size: int
|
||||
attn_backend: str
|
||||
method: str
|
||||
fullgraph: bool
|
||||
|
||||
|
||||
# we cannot afford testing the full Cartesian product
|
||||
@@ -36,7 +35,6 @@ class TestSetting:
|
||||
tp_size=2,
|
||||
attn_backend="FLASH_ATTN",
|
||||
method="generate",
|
||||
fullgraph=True,
|
||||
),
|
||||
# llama model with quantization
|
||||
TestSetting(
|
||||
@@ -46,7 +44,6 @@ class TestSetting:
|
||||
tp_size=1,
|
||||
attn_backend="FLASH_ATTN",
|
||||
method="generate",
|
||||
fullgraph=True,
|
||||
),
|
||||
# MoE model
|
||||
TestSetting(
|
||||
@@ -56,7 +53,6 @@ class TestSetting:
|
||||
tp_size=2,
|
||||
attn_backend="FLASH_ATTN",
|
||||
method="generate",
|
||||
fullgraph=True,
|
||||
),
|
||||
# embedding model
|
||||
TestSetting(
|
||||
@@ -73,7 +69,6 @@ class TestSetting:
|
||||
tp_size=1,
|
||||
attn_backend="FLASH_ATTN",
|
||||
method="encode",
|
||||
fullgraph=True,
|
||||
),
|
||||
TestSetting(
|
||||
model="BAAI/bge-base-en-v1.5",
|
||||
@@ -82,7 +77,6 @@ class TestSetting:
|
||||
tp_size=1,
|
||||
attn_backend="FLASH_ATTN",
|
||||
method="encode",
|
||||
fullgraph=True,
|
||||
),
|
||||
# vision language model
|
||||
TestSetting(
|
||||
@@ -92,7 +86,6 @@ class TestSetting:
|
||||
tp_size=1,
|
||||
attn_backend="FLASH_ATTN",
|
||||
method="generate_with_image",
|
||||
fullgraph=False,
|
||||
),
|
||||
],
|
||||
)
|
||||
@@ -109,9 +102,8 @@ def test_compile_correctness(
|
||||
tp_size = test_setting.tp_size
|
||||
attn_backend = test_setting.attn_backend
|
||||
method = test_setting.method
|
||||
fullgraph = test_setting.fullgraph
|
||||
if cuda_device_count_stateless() != pp_size * tp_size:
|
||||
pytest.skip(f"Need exactly {pp_size}*{tp_size} CUDA gpus but got "
|
||||
if cuda_device_count_stateless() < pp_size * tp_size:
|
||||
pytest.skip(f"Need at least {pp_size}*{tp_size} CUDA gpus but got "
|
||||
f"{cuda_device_count_stateless()}")
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
@@ -149,9 +141,5 @@ def test_compile_correctness(
|
||||
]:
|
||||
all_args.append(final_args + [f"-O{level}"])
|
||||
all_envs.append({})
|
||||
if level != CompilationLevel.DYNAMO_ONCE and not fullgraph:
|
||||
# "DYNAMO_ONCE" will always use fullgraph
|
||||
all_envs[-1][
|
||||
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "0" # type: ignore
|
||||
|
||||
compare_all_settings(model, all_args * 3, all_envs, method=method)
|
||||
|
||||
Reference in New Issue
Block a user