fix[ROCm]: Remove unconditional aiter import (#32902)

Signed-off-by: rabi <ramishra@redhat.com>
This commit is contained in:
Rabi Mishra
2026-02-02 19:40:02 +05:30
committed by GitHub
parent b10d05b8a8
commit 9eb58f8cf1
3 changed files with 18 additions and 7 deletions

View File

@@ -31,7 +31,12 @@ IS_AITER_FOUND = is_aiter_found()
def is_aiter_found_and_supported() -> bool:
if current_platform.is_rocm() and IS_AITER_FOUND:
"""Check if AITER is available AND enabled via environment variable.
Checks: platform (ROCm), device arch (gfx9), library existence,
and VLLM_ROCM_USE_AITER env variable.
"""
if current_platform.is_rocm() and IS_AITER_FOUND and envs.VLLM_ROCM_USE_AITER:
from vllm.platforms.rocm import on_gfx9
return on_gfx9()
@@ -40,13 +45,11 @@ def is_aiter_found_and_supported() -> bool:
def if_aiter_supported(func: Callable) -> Callable:
"""Decorator that only executes the function if
ROCm AITER package is supported on gfx9 archs.
ROCm AITER package is supported and enabled on gfx9 archs.
"""
@functools.wraps(func)
def wrapper(*args, **kwargs):
# checks the platform, device arch and aiter library existence.
if is_aiter_found_and_supported():
return func(*args, **kwargs)
@@ -63,6 +66,11 @@ if is_aiter_found_and_supported():
from aiter import dtypes
AITER_FP8_DTYPE = dtypes.fp8
else:
# Placeholder when AITER is disabled - prevents NameError during module load.
# Note: When AITER is disabled, ops are not registered, so fake implementations
# referencing this variable won't actually be called at runtime.
AITER_FP8_DTYPE = _FP8_DTYPE
def _rocm_aiter_fused_moe_impl(

View File

@@ -32,10 +32,11 @@ from vllm.v1.kv_cache_interface import AttentionSpec
_PARTITION_SIZE_ROCM = 256
_CP_TOKENS_PER_ITER_ROCM = 32 * 1024
if current_platform.is_rocm():
import aiter
from vllm.triton_utils import tl, triton
if rocm_aiter_ops.is_enabled():
import aiter
def block_size(x, head_dim):
return min(65536 // x.element_size(), triton.next_power_of_2(head_dim))

View File

@@ -183,7 +183,9 @@ class SpecDecodeBaseProposer:
RocmAttentionMetadata,
]
# ROCM_AITER_FA is an optional backend
if find_spec(
from vllm._aiter_ops import rocm_aiter_ops
if rocm_aiter_ops.is_enabled() and find_spec(
AttentionBackendEnum.ROCM_AITER_FA.get_path(include_classname=False)
):
from vllm.v1.attention.backends.rocm_aiter_fa import (