[CI/Build][AMD] Fix import errors in tests/kernels/attention (#29032)

Signed-off-by: Randall Smith <ransmith@amd.com>
Co-authored-by: Randall Smith <ransmith@amd.com>
This commit is contained in:
rasmith
2025-11-20 03:48:09 -06:00
committed by GitHub
parent 2c52c7fd9a
commit 322cb02872
6 changed files with 49 additions and 15 deletions

View File

@@ -3,7 +3,6 @@
import pytest
import torch
import torch.nn.functional as F
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
from torch import Tensor
from vllm.platforms import current_platform
@@ -15,6 +14,8 @@ if not current_platform.has_device_capability(100):
reason="FlashInfer MLA Requires compute capability of 10 or above.",
allow_module_level=True,
)
else:
from flashinfer.decode import trtllm_batch_decode_with_kv_cache_mla
def ref_mla(