diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index c7f925817..ad454daa5 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -18,6 +18,7 @@ from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.platforms import current_platform from vllm.utils.math_utils import cdiv +from vllm.utils.platform_utils import get_cu_count from vllm.v1.attention.backends.utils import ( AttentionCGSupport, AttentionMetadataBuilder, @@ -38,7 +39,7 @@ if current_platform.is_rocm(): return min(65536 // x.element_size(), triton.next_power_of_2(head_dim)) def num_programs(total_tokens): - return min(total_tokens, current_platform.get_cu_count()) + return min(total_tokens, get_cu_count()) @triton.jit def cp_mha_gather_cache_kernel(