[Hardware][ROCM] using current_platform.is_rocm (#9642)
Signed-off-by: wangshuai09 <391746016@qq.com>
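
The change itself is mechanical: each call to the old vllm.utils.is_hip() helper becomes current_platform.is_rocm() on the platform object exported by vllm.platforms. A minimal sketch of the pattern the commit adopts (the test name and body below are illustrative, not part of this commit):

import pytest
import torch

from vllm.platforms import current_platform

# ROCm builds drop float32 here because the flash-attention backends
# exercised by these tests only support fp16/bf16.
DTYPES = ([torch.half, torch.bfloat16, torch.float]
          if not current_platform.is_rocm() else [torch.half, torch.bfloat16])


@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="Xformers backend is not supported on ROCm.")
@pytest.mark.parametrize("dtype", DTYPES)
def test_example(dtype):
    # Illustrative body only; the real tests exercise the attention kernels.
    assert dtype in DTYPES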
@@ -6,11 +6,12 @@ import torch
 
 from tests.kernels.utils import opcheck
 from vllm import _custom_ops as ops
-from vllm.utils import get_max_shared_memory_bytes, is_hip, seed_everything
+from vllm.platforms import current_platform
+from vllm.utils import get_max_shared_memory_bytes, seed_everything
 
 from .allclose_default import get_default_atol, get_default_rtol
 
-if not is_hip():
+if not current_platform.is_rocm():
     from xformers import ops as xops
     from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
 
@@ -23,8 +24,9 @@ MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
 NUM_BLOCKS = 4321  # Arbitrary values for testing
 PARTITION_SIZE = 512
 # flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
-DTYPES = [torch.half, torch.bfloat16, torch.float
-          ] if not is_hip() else [torch.half, torch.bfloat16]
+DTYPES = [
+    torch.half, torch.bfloat16, torch.float
+] if not current_platform.is_rocm() else [torch.half, torch.bfloat16]
 NUM_GEN_SEQS = [7]  # Arbitrary values for testing
 NUM_PREFILL_SEQS = [3]  # Arbitrary values for testing
 NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing
@@ -114,7 +116,8 @@ def ref_single_query_cached_kv_attention(
 
 
 @pytest.mark.parametrize(
-    "version", ["v1", "v2"] if not is_hip() else ["v1", "v2", "rocm"])
+    "version",
+    ["v1", "v2"] if not current_platform.is_rocm() else ["v1", "v2", "rocm"])
 @pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS)
 @pytest.mark.parametrize("num_heads", NUM_HEADS)
 @pytest.mark.parametrize("head_size", HEAD_SIZES)
@@ -317,8 +320,8 @@ def test_paged_attention(
     # NOTE(woosuk): Due to the kernel-level differences in the two
     # implementations, there is a small numerical difference in the two
     # outputs. Thus, we use a relaxed tolerance for the test.
-    atol = get_default_atol(output) if is_hip() else 1e-3
-    rtol = get_default_rtol(output) if is_hip() else 1e-5
+    atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
+    rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
 
     # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
     # so we use a relaxed tolerance for the test.
@@ -368,7 +371,7 @@ def ref_multi_query_kv_attention(
 @pytest.mark.parametrize("dtype", DTYPES)
 @pytest.mark.parametrize("seed", SEEDS)
 @pytest.mark.parametrize("device", CUDA_DEVICES)
-@pytest.mark.skipif(is_hip(),
+@pytest.mark.skipif(current_platform.is_rocm(),
                     reason="Xformers backend is not supported on ROCm.")
 @torch.inference_mode()
 def test_multi_query_kv_attention(
@@ -425,6 +428,6 @@ def test_multi_query_kv_attention(
         scale,
         dtype,
     )
-    atol = get_default_atol(output) if is_hip() else 1e-3
-    rtol = get_default_rtol(output) if is_hip() else 1e-5
+    atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3
+    rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5
     torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol)
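
For context: current_platform is resolved once, when vllm.platforms is imported, to a singleton for the detected hardware, so the checks above are cheap method calls rather than repeated environment probes. A quick interactive check, assuming that interface:

from vllm.platforms import current_platform

# Prints the resolved platform class, e.g. CudaPlatform on CUDA builds
# or RocmPlatform on ROCm (class names are assumptions about
# vllm.platforms at the time of this commit).
print(type(current_platform).__name__)
print(current_platform.is_rocm())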