[ROCm][CI] Fix tests/compile unit tests (#28895)

Signed-off-by: charlifu <charlifu@amd.com>
Signed-off-by: Micah Williamson <micah.williamson@amd.com>
Signed-off-by: Charlie Fu <Charlie.Fu@amd.com>
Co-authored-by: Micah Williamson <micah.williamson@amd.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Author: Charlie Fu
Date: 2026-01-06 12:50:43 -06:00 (committed by GitHub)
Parent: f7008ce1c4
Commit: c07163663d
3 changed files with 30 additions and 13 deletions


@@ -5,10 +5,13 @@ import dataclasses
 import pytest
 from vllm.config import CompilationMode
+from vllm.platforms import current_platform
 from vllm.utils.torch_utils import cuda_device_count_stateless
 from ...utils import compare_all_settings
+ATTN_BACKEND = "FLASH_ATTN" if not current_platform.is_rocm() else "ROCM_ATTN"
 @dataclasses.dataclass
 class TestSetting:
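
For orientation, here is a hedged sketch of what this first hunk sets up: ATTN_BACKEND picks the attention backend per platform, and the TestSetting fields below are reconstructed only from what is visible in this diff (the real dataclass may define additional fields):

import dataclasses

from vllm.platforms import current_platform

# FLASH_ATTN is unavailable on ROCm in this CI setup, so fall back to ROCM_ATTN.
ATTN_BACKEND = "FLASH_ATTN" if not current_platform.is_rocm() else "ROCM_ATTN"


@dataclasses.dataclass
class TestSetting:
    model: str             # HF model id, e.g. "BAAI/bge-base-en-v1.5"
    model_args: list[str]  # extra CLI args, e.g. ["--max-model-len", "2048"]
    pp_size: int           # pipeline-parallel size
    tp_size: int           # tensor-parallel size
    attn_backend: str      # forced attention backend, usually ATTN_BACKEND
    method: str            # "generate" or "encode"
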
@@ -31,7 +34,7 @@ class TestSetting:
         model_args=["--max-model-len", "2048"],
         pp_size=2,
         tp_size=2,
-        attn_backend="FLASH_ATTN",
+        attn_backend=ATTN_BACKEND,
         method="generate",
     ),
     # llama model with quantization
@@ -40,7 +43,7 @@ class TestSetting:
         model_args=["--quantization", "gptq", "--max-model-len", "2048"],
         pp_size=1,
         tp_size=1,
-        attn_backend="FLASH_ATTN",
+        attn_backend=ATTN_BACKEND,
         method="generate",
     ),
     # MoE model
@@ -49,7 +52,7 @@ class TestSetting:
         model_args=["--max-model-len", "2048"],
         pp_size=1,
         tp_size=2,
-        attn_backend="FLASH_ATTN",
+        attn_backend=ATTN_BACKEND,
         method="generate",
     ),
     # embedding model
@@ -65,16 +68,22 @@ class TestSetting:
         ],
         pp_size=1,
         tp_size=1,
-        attn_backend="FLASH_ATTN",
+        attn_backend=ATTN_BACKEND,
         method="encode",
     ),
-    TestSetting(
-        model="BAAI/bge-base-en-v1.5",
-        model_args=["--runner", "pooling"],
-        pp_size=1,
-        tp_size=1,
-        attn_backend="FLASH_ATTN",
-        method="encode",
+    pytest.param(
+        TestSetting(
+            model="BAAI/bge-base-en-v1.5",
+            model_args=["--runner", "pooling"],
+            pp_size=1,
+            tp_size=1,
+            attn_backend="FLASH_ATTN",
+            method="encode",
+        ),
+        marks=pytest.mark.skipif(
+            current_platform.is_rocm(),
+            reason="Encoder self-attention is not implemented for ROCm",
+        ),
     ),
     # vision language model
     # See https://github.com/vllm-project/vllm/issues/26716.
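
The last hunk wraps the pooling case in pytest.param with marks=pytest.mark.skipif, which keeps the case in the parametrized list but skips it where it cannot run. A self-contained sketch of the same pattern, with hypothetical case names and a stand-in skip condition:

import sys

import pytest

CASES = [
    "always-runs",
    pytest.param(
        "platform-specific",
        marks=pytest.mark.skipif(
            sys.platform == "win32",  # stand-in for current_platform.is_rocm()
            reason="not implemented on this platform",
        ),
    ),
]


@pytest.mark.parametrize("case", CASES)
def test_cases(case: str) -> None:
    # Each entry in CASES becomes one test; marked entries are reported as
    # skipped when their condition is true, without leaving the list.
    assert isinstance(case, str)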