[CI/Build][AMD] Fix import errors in tests/kernels/attention (#29032)
Signed-off-by: Randall Smith <ransmith@amd.com> Co-authored-by: Randall Smith <ransmith@amd.com>
This commit is contained in:
@@ -2,12 +2,20 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
|
||||
import flashinfer
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
try:
|
||||
import flashinfer
|
||||
except ImportError:
|
||||
if current_platform.is_rocm():
|
||||
pytest.skip(
|
||||
"flashinfer is not supported for vLLM on ROCm.", allow_module_level=True
|
||||
)
|
||||
|
||||
import torch
|
||||
|
||||
NUM_HEADS = [(32, 8), (6, 1)]
|
||||
HEAD_SIZES = [128, 256]
|
||||
BLOCK_SIZES = [16, 32]
|
||||
|
||||
Reference in New Issue
Block a user