diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index 610891ebf..7228d92f7 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -31,7 +31,12 @@ IS_AITER_FOUND = is_aiter_found() def is_aiter_found_and_supported() -> bool: - if current_platform.is_rocm() and IS_AITER_FOUND: + """Check if AITER is available AND enabled via environment variable. + + Checks: platform (ROCm), device arch (gfx9), library existence, + and VLLM_ROCM_USE_AITER env variable. + """ + if current_platform.is_rocm() and IS_AITER_FOUND and envs.VLLM_ROCM_USE_AITER: from vllm.platforms.rocm import on_gfx9 return on_gfx9() @@ -40,13 +45,11 @@ def is_aiter_found_and_supported() -> bool: def if_aiter_supported(func: Callable) -> Callable: """Decorator that only executes the function if - ROCm AITER package is supported on gfx9 archs. + ROCm AITER package is supported and enabled on gfx9 archs. """ @functools.wraps(func) def wrapper(*args, **kwargs): - # checks the platform, device arch and aiter library existence. - if is_aiter_found_and_supported(): return func(*args, **kwargs) @@ -63,6 +66,11 @@ if is_aiter_found_and_supported(): from aiter import dtypes AITER_FP8_DTYPE = dtypes.fp8 +else: + # Placeholder when AITER is disabled - prevents NameError during module load. + # Note: When AITER is disabled, ops are not registered, so fake implementations + # referencing this variable won't actually be called at runtime. + AITER_FP8_DTYPE = _FP8_DTYPE def _rocm_aiter_fused_moe_impl( diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index 8ed06fc8d..18194e05f 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -32,10 +32,11 @@ from vllm.v1.kv_cache_interface import AttentionSpec _PARTITION_SIZE_ROCM = 256 _CP_TOKENS_PER_ITER_ROCM = 32 * 1024 if current_platform.is_rocm(): - import aiter - from vllm.triton_utils import tl, triton + if rocm_aiter_ops.is_enabled(): + import aiter + def block_size(x, head_dim): return min(65536 // x.element_size(), triton.next_power_of_2(head_dim)) diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 43b84f4be..45680a796 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -183,7 +183,9 @@ class SpecDecodeBaseProposer: RocmAttentionMetadata, ] # ROCM_AITER_FA is an optional backend - if find_spec( + from vllm._aiter_ops import rocm_aiter_ops + + if rocm_aiter_ops.is_enabled() and find_spec( AttentionBackendEnum.ROCM_AITER_FA.get_path(include_classname=False) ): from vllm.v1.attention.backends.rocm_aiter_fa import (