[CI/Build][AMD] Fix import errors in tests/kernels/attention (#29032)
Signed-off-by: Randall Smith <ransmith@amd.com> Co-authored-by: Randall Smith <ransmith@amd.com>
This commit is contained in:
@@ -3,7 +3,6 @@
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
|
||||
from torch import Tensor
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
@@ -15,6 +14,8 @@ if not current_platform.has_device_capability(100):
|
||||
reason="FlashInfer MLA Requires compute capability of 10 or above.",
|
||||
allow_module_level=True,
|
||||
)
|
||||
else:
|
||||
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
|
||||
|
||||
|
||||
def ref_mla(
|
||||
|
||||
Reference in New Issue
Block a user