[ROCm][CI] Fix tests/compile unit tests (#28895)
Signed-off-by: charlifu <charlifu@amd.com>
Signed-off-by: Micah Williamson <micah.williamson@amd.com>
Signed-off-by: Charlie Fu <Charlie.Fu@amd.com>
Co-authored-by: Micah Williamson <micah.williamson@amd.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
@@ -5,10 +5,13 @@ import dataclasses
 import pytest
 
 from vllm.config import CompilationMode
+from vllm.platforms import current_platform
 from vllm.utils.torch_utils import cuda_device_count_stateless
 
 from ...utils import compare_all_settings
 
+ATTN_BACKEND = "FLASH_ATTN" if not current_platform.is_rocm() else "ROCM_ATTN"
+
 
 @dataclasses.dataclass
 class TestSetting:
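The heart of the fix in this file is the ATTN_BACKEND constant above: the backend name is resolved once at import time and reused by every test setting below, so the whole parametrization table becomes platform-portable. A minimal, self-contained sketch of the pattern using only stock PyTorch (the is_rocm helper is a stand-in for vllm.platforms.current_platform.is_rocm(); everything else is illustrative):

# Sketch: resolve a per-platform default once at module import, then reuse it.
# torch.version.hip is a string on ROCm builds of PyTorch and None otherwise.
import torch

def is_rocm() -> bool:
    return torch.version.hip is not None

ATTN_BACKEND = "FLASH_ATTN" if not is_rocm() else "ROCM_ATTN"
print(f"selected attention backend: {ATTN_BACKEND}")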
@@ -31,7 +34,7 @@ class TestSetting:
         model_args=["--max-model-len", "2048"],
         pp_size=2,
         tp_size=2,
-        attn_backend="FLASH_ATTN",
+        attn_backend=ATTN_BACKEND,
         method="generate",
     ),
     # llama model with quantization
@@ -40,7 +43,7 @@ class TestSetting:
         model_args=["--quantization", "gptq", "--max-model-len", "2048"],
         pp_size=1,
         tp_size=1,
-        attn_backend="FLASH_ATTN",
+        attn_backend=ATTN_BACKEND,
         method="generate",
     ),
     # MoE model
@@ -49,7 +52,7 @@ class TestSetting:
         model_args=["--max-model-len", "2048"],
         pp_size=1,
         tp_size=2,
-        attn_backend="FLASH_ATTN",
+        attn_backend=ATTN_BACKEND,
         method="generate",
     ),
     # embedding model
@@ -65,16 +68,22 @@ class TestSetting:
         ],
         pp_size=1,
         tp_size=1,
-        attn_backend="FLASH_ATTN",
+        attn_backend=ATTN_BACKEND,
         method="encode",
     ),
-    TestSetting(
-        model="BAAI/bge-base-en-v1.5",
-        model_args=["--runner", "pooling"],
-        pp_size=1,
-        tp_size=1,
-        attn_backend="FLASH_ATTN",
-        method="encode",
+    pytest.param(
+        TestSetting(
+            model="BAAI/bge-base-en-v1.5",
+            model_args=["--runner", "pooling"],
+            pp_size=1,
+            tp_size=1,
+            attn_backend="FLASH_ATTN",
+            method="encode",
+        ),
+        marks=pytest.mark.skipif(
+            current_platform.is_rocm(),
+            reason="Encoder self-attention is not implemented for ROCm",
+        ),
     ),
     # vision language model
     # See https://github.com/vllm-project/vllm/issues/26716.
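The bge-base case cannot simply switch backends, so it is wrapped in pytest.param with a skipif mark instead: that skips this one parametrized case on ROCm while the rest of the table keeps running. A minimal sketch of the pattern, with a hypothetical IS_ROCM flag standing in for current_platform.is_rocm():

# Sketch: skip a single case of a parametrized test without touching the rest.
import pytest

IS_ROCM = False  # stand-in for current_platform.is_rocm()

@pytest.mark.parametrize(
    "backend",
    [
        "FLASH_ATTN",
        pytest.param(
            "ENCODER_SELF_ATTN",  # hypothetical backend name
            marks=pytest.mark.skipif(
                IS_ROCM, reason="Encoder self-attention is not implemented for ROCm"
            ),
        ),
    ],
)
def test_backend(backend: str) -> None:
    assert backend  # on ROCm the second case is reported as skipped, not failed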
@@ -9,6 +9,7 @@ import pytest
 from tests.utils import wait_for_gpu_memory_to_clear
 from tests.v1.attention.utils import full_cg_backend_configs as backend_configs
 from vllm import LLM, SamplingParams
+from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.config import CompilationConfig
 from vllm.platforms import current_platform
 from vllm.utils.torch_utils import is_torch_equal_or_newer
@@ -70,6 +71,10 @@ def llm_pair(request):
     elif backend_config.specific_gpu_arch == (10, 0):
         pytest.skip("Only Blackwell GPUs support Cutlass MLA")
 
+    # FlashInfer is not supported on ROCm
+    if backend_config == AttentionBackendEnum.FLASHINFER and current_platform.is_rocm():
+        pytest.skip("FlashInfer is not supported on ROCm")
+
     env_vars = {
         # Force native sampler to avoid potential nondeterminism in FlashInfer
         # when per-request generators are not used in V1.
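Because llm_pair is a fixture, calling pytest.skip inside it skips every test that requests the FlashInfer configuration on ROCm before any engine is built, rather than failing mid-setup. A minimal sketch of that guard-in-a-fixture pattern (the names and support table are illustrative, not the real fixture):

# Sketch: a fixture that skips unsupported parameter/platform combinations
# before doing any expensive setup.
import pytest

SUPPORTED_BACKENDS = {"FLASH_ATTN"}  # hypothetical support table

@pytest.fixture(params=["FLASH_ATTN", "FLASHINFER"])
def backend(request):
    if request.param not in SUPPORTED_BACKENDS:
        pytest.skip(f"{request.param} is not supported on this platform")
    return request.param

def test_runs_only_supported(backend):
    assert backend in SUPPORTED_BACKENDS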
@@ -25,10 +25,13 @@ def test_noop_elimination(dtype, num_tokens, hidden_size, buffer_size):
     class Model(torch.nn.Module):
         def __init__(self) -> None:
             super().__init__()
-            self.pos_embed = torch.empty(buffer_size, hidden_size, dtype=dtype)
+            # Avoid using empty, since on rocm torch.empty
+            # does not initialize the memory.
+            self.pos_embed = torch.randn(buffer_size, hidden_size, dtype=dtype)
 
         def forward(self, x):
-            x += self.pos_embed[: x.shape[0]]
+            # Avoid += to prevent inplace addition.
+            x = x + self.pos_embed[: x.shape[0]]
             # Chain of reshapes
             y = x.reshape(-1, 128, 32)
             z = y.reshape(-1, 4096)
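Both changes remove flakiness rather than alter the reshape chain the pass is testing: torch.empty returns whatever uninitialized bytes are in the allocation (which, per the new comment, is what broke the comparison on ROCm), and += lowers to an in-place add, which the test now avoids in favor of the purely functional x + y. A short sketch of the two behaviors in plain PyTorch:

# Sketch: why randn + out-of-place add make the test reproducible.
import torch

# torch.empty gives uninitialized memory; its contents are arbitrary, so any
# output comparison built from it is unreliable. torch.randn is fully
# initialized and safe to compare against.
buf = torch.empty(4, 8)      # contents undefined
buf = torch.randn(4, 8)      # deterministic-enough replacement

x = torch.zeros(4, 8)
alias = x
alias += buf                 # in-place add: mutates x through the alias
assert alias is x

x2 = torch.zeros(4, 8)
y = x2 + buf                 # out-of-place add: new tensor, x2 untouched
assert y is not x2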