fix[ROCm]: Remove unconditional aiter import (#32902)
Signed-off-by: rabi <ramishra@redhat.com>
This commit is contained in:
@@ -31,7 +31,12 @@ IS_AITER_FOUND = is_aiter_found()
|
||||
|
||||
|
||||
def is_aiter_found_and_supported() -> bool:
|
||||
if current_platform.is_rocm() and IS_AITER_FOUND:
|
||||
"""Check if AITER is available AND enabled via environment variable.
|
||||
|
||||
Checks: platform (ROCm), device arch (gfx9), library existence,
|
||||
and VLLM_ROCM_USE_AITER env variable.
|
||||
"""
|
||||
if current_platform.is_rocm() and IS_AITER_FOUND and envs.VLLM_ROCM_USE_AITER:
|
||||
from vllm.platforms.rocm import on_gfx9
|
||||
|
||||
return on_gfx9()
|
||||
@@ -40,13 +45,11 @@ def is_aiter_found_and_supported() -> bool:
|
||||
|
||||
def if_aiter_supported(func: Callable) -> Callable:
|
||||
"""Decorator that only executes the function if
|
||||
ROCm AITER package is supported on gfx9 archs.
|
||||
ROCm AITER package is supported and enabled on gfx9 archs.
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
# checks the platform, device arch and aiter library existence.
|
||||
|
||||
if is_aiter_found_and_supported():
|
||||
return func(*args, **kwargs)
|
||||
|
||||
@@ -63,6 +66,11 @@ if is_aiter_found_and_supported():
|
||||
from aiter import dtypes
|
||||
|
||||
AITER_FP8_DTYPE = dtypes.fp8
|
||||
else:
|
||||
# Placeholder when AITER is disabled - prevents NameError during module load.
|
||||
# Note: When AITER is disabled, ops are not registered, so fake implementations
|
||||
# referencing this variable won't actually be called at runtime.
|
||||
AITER_FP8_DTYPE = _FP8_DTYPE
|
||||
|
||||
|
||||
def _rocm_aiter_fused_moe_impl(
|
||||
|
||||
@@ -32,10 +32,11 @@ from vllm.v1.kv_cache_interface import AttentionSpec
|
||||
_PARTITION_SIZE_ROCM = 256
|
||||
_CP_TOKENS_PER_ITER_ROCM = 32 * 1024
|
||||
if current_platform.is_rocm():
|
||||
import aiter
|
||||
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
if rocm_aiter_ops.is_enabled():
|
||||
import aiter
|
||||
|
||||
def block_size(x, head_dim):
    """Return the block size to use for ``x`` given ``head_dim``.

    The result is the smaller of two values:

    * ``65536 // x.element_size()`` -- a cap that shrinks as the
      per-element byte size reported by ``x`` grows, and
    * ``triton.next_power_of_2(head_dim)`` -- ``head_dim`` rounded up to a
      power of two.
    """
    # Cap derived from the element size of ``x``.
    element_cap = 65536 // x.element_size()
    # ``head_dim`` rounded up to the next power of two.
    padded_head_dim = triton.next_power_of_2(head_dim)
    return min(element_cap, padded_head_dim)
|
||||
|
||||
|
||||
@@ -183,7 +183,9 @@ class SpecDecodeBaseProposer:
|
||||
RocmAttentionMetadata,
|
||||
]
|
||||
# ROCM_AITER_FA is an optional backend
|
||||
if find_spec(
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
|
||||
if rocm_aiter_ops.is_enabled() and find_spec(
|
||||
AttentionBackendEnum.ROCM_AITER_FA.get_path(include_classname=False)
|
||||
):
|
||||
from vllm.v1.attention.backends.rocm_aiter_fa import (
|
||||
|
||||
Reference in New Issue
Block a user