[ROCm] Fix some kernels failed unit tests (#2498)
@@ -8,6 +8,8 @@ from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
 from vllm._C import ops, cache_ops
 from vllm.utils import get_max_shared_memory_bytes
+from vllm.utils import is_hip
+from allclose_default import get_default_atol, get_default_rtol
 
 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
 # This will change depending on the compute capability.
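
For context, is_hip() (imported above from vllm.utils) just detects a ROCm build of PyTorch, so everything below branches on the platform at import time. A minimal sketch, assuming the helper simply probes torch.version:

import torch

def is_hip() -> bool:
    # ROCm builds of PyTorch expose a HIP version string; CUDA builds expose None.
    return torch.version.hip is not None
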
@@ -17,12 +19,18 @@ MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512
 # Reduce NUM_BLOCKS when it happens.
 NUM_BLOCKS = 4321  # Arbitrary values for testing
 PARTITION_SIZE = 512
-DTYPES = [torch.half, torch.bfloat16, torch.float]
+# flshattF and tritonflashattF supported: {torch.float16, torch.bfloat16}
+DTYPES = [torch.half, torch.bfloat16, torch.float
+          ] if not is_hip() else [torch.half, torch.bfloat16]
 NUM_GEN_SEQS = [7]  # Arbitrary values for testing
 NUM_PREFILL_SEQS = [3]  # Arbitrary values for testing
 NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing
-HEAD_SIZES = [64, 80, 96, 112, 128, 256]
+
+# FlashAttention forward only supports head dimension at most 128
+# https://github.com/ROCmSoftwarePlatform/flash-attention/blob/3d2b6f5d037782cc2c906909a46fb7e2e1b48b25/csrc/flash_attn_rocm/flash_api.cpp#L62
+HEAD_SIZES = [64, 80, 96, 112, 128, 256
+              ] if not is_hip() else [64, 80, 96, 112, 128]
 
 BLOCK_SIZES = [16, 32]
 USE_ALIBI = [False, True]
 KV_CACHE_DTYPE = ["auto", "fp8_e5m2"]
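
These module-level lists drive pytest parametrization, so trimming them under is_hip() means the unsupported ROCm combinations (fp32 for the flash-attention backends, head size 256) are never generated as test cases rather than generated and failed. A minimal sketch of the pattern, with a hypothetical test name and the trimmed lists inlined:

import pytest
import torch

# ROCm-trimmed lists, as in the diff above.
DTYPES = [torch.half, torch.bfloat16]
HEAD_SIZES = [64, 80, 96, 112, 128]

@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
def test_attention_case(dtype: torch.dtype, head_size: int) -> None:
    # Each (dtype, head_size) pair is collected as its own test, so the
    # combinations dropped from the lists simply never run.
    assert dtype in DTYPES and head_size in HEAD_SIZES
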
@@ -251,9 +259,11 @@ def test_paged_attention(
     # NOTE(woosuk): Due to the kernel-level differences in the two
     # implementations, there is a small numerical difference in the two
     # outputs. Thus, we use a relaxed tolerance for the test.
+    atol = get_default_atol(output) if is_hip() else 1e-3
+    rtol = get_default_rtol(output) if is_hip() else 1e-5
+
     # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error,
     # so we use a relaxed tolerance for the test.
-    atol, rtol = 1e-3, 1e-5
     if kv_cache_dtype == "fp8_e5m2":
         atol, rtol = 1e-2, 1e-5
     assert torch.allclose(output, ref_output, atol=atol, rtol=rtol)
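
get_default_atol/get_default_rtol come from the new allclose_default helper, which keys tolerances off the output tensor's dtype instead of hard-coding fp32-grade bounds. A plausible sketch, assuming dtype-indexed defaults along the lines of PyTorch's own transformer-test tolerances (the exact values are an assumption, not taken from this commit):

import torch

# Assumed per-dtype defaults, modeled on PyTorch's transformer tests.
default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float: 1e-5}
default_rtol = {torch.float16: 1e-3, torch.bfloat16: 1.6e-2, torch.float: 1.3e-6}

def get_default_atol(output: torch.Tensor) -> float:
    return default_atol[output.dtype]

def get_default_rtol(output: torch.Tensor) -> float:
    return default_rtol[output.dtype]
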
@@ -357,4 +367,6 @@ def test_multi_query_kv_attention(
         scale,
         dtype,
     )
-    assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)
+    atol = get_default_atol(output) if is_hip() else 1e-3
+    rtol = get_default_rtol(output) if is_hip() else 1e-5
+    assert torch.allclose(output, ref_output, atol=atol, rtol=rtol)
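
For reference, torch.allclose(output, ref_output, atol=atol, rtol=rtol) passes when |output - ref_output| <= atol + rtol * |ref_output| holds elementwise, so the dtype-aware atol/rtol widen the acceptance band on ROCm exactly where the kernels are expected to be less precise. A quick illustration:

import torch

a = torch.tensor([1.0000, 2.0000])
b = torch.tensor([1.0005, 2.0005])
# The relaxed half-precision-style tolerances accept a 5e-4 error...
assert torch.allclose(a, b, atol=1e-3, rtol=1e-3)
# ...while strict fp32-style defaults reject the same pair.
assert not torch.allclose(a, b, atol=1e-5, rtol=1.3e-6)
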