fix[ROCm]: Remove unconditional aiter import (#32902)

Signed-off-by: rabi <ramishra@redhat.com>
This commit is contained in:
Rabi Mishra
2026-02-02 19:40:02 +05:30
committed by GitHub
parent b10d05b8a8
commit 9eb58f8cf1
3 changed files with 18 additions and 7 deletions

View File

@@ -31,7 +31,12 @@ IS_AITER_FOUND = is_aiter_found()
 def is_aiter_found_and_supported() -> bool:
-    if current_platform.is_rocm() and IS_AITER_FOUND:
+    """Check if AITER is available AND enabled via environment variable.
+
+    Checks: platform (ROCm), device arch (gfx9), library existence,
+    and VLLM_ROCM_USE_AITER env variable.
+    """
+    if current_platform.is_rocm() and IS_AITER_FOUND and envs.VLLM_ROCM_USE_AITER:
         from vllm.platforms.rocm import on_gfx9

         return on_gfx9()
@@ -40,13 +45,11 @@ def is_aiter_found_and_supported() -> bool:
 def if_aiter_supported(func: Callable) -> Callable:
     """Decorator that only executes the function if
-    ROCm AITER package is supported on gfx9 archs.
+    ROCm AITER package is supported and enabled on gfx9 archs.
     """

     @functools.wraps(func)
     def wrapper(*args, **kwargs):
-        # checks the platform, device arch and aiter library existence.
         if is_aiter_found_and_supported():
             return func(*args, **kwargs)
@@ -63,6 +66,11 @@ if is_aiter_found_and_supported():
     from aiter import dtypes

     AITER_FP8_DTYPE = dtypes.fp8
+else:
+    # Placeholder when AITER is disabled - prevents NameError during module load.
+    # Note: When AITER is disabled, ops are not registered, so fake implementations
+    # referencing this variable won't actually be called at runtime.
+    AITER_FP8_DTYPE = _FP8_DTYPE


 def _rocm_aiter_fused_moe_impl(

View File

@@ -32,10 +32,11 @@ from vllm.v1.kv_cache_interface import AttentionSpec
 _PARTITION_SIZE_ROCM = 256
 _CP_TOKENS_PER_ITER_ROCM = 32 * 1024

 if current_platform.is_rocm():
-    import aiter
-
     from vllm.triton_utils import tl, triton

+    if rocm_aiter_ops.is_enabled():
+        import aiter
+
     def block_size(x, head_dim):
         return min(65536 // x.element_size(), triton.next_power_of_2(head_dim))

View File

@@ -183,7 +183,9 @@ class SpecDecodeBaseProposer:
     RocmAttentionMetadata,
 ]
 # ROCM_AITER_FA is an optional backend
-if find_spec(
+from vllm._aiter_ops import rocm_aiter_ops
+
+if rocm_aiter_ops.is_enabled() and find_spec(
     AttentionBackendEnum.ROCM_AITER_FA.get_path(include_classname=False)
 ):
     from vllm.v1.attention.backends.rocm_aiter_fa import (