fix[ROCm]: Remove unconditional aiter import (#32902)
Signed-off-by: rabi <ramishra@redhat.com>
This commit is contained in:
@@ -31,7 +31,12 @@ IS_AITER_FOUND = is_aiter_found()
|
|||||||
|
|
||||||
|
|
||||||
def is_aiter_found_and_supported() -> bool:
|
def is_aiter_found_and_supported() -> bool:
|
||||||
if current_platform.is_rocm() and IS_AITER_FOUND:
|
"""Check if AITER is available AND enabled via environment variable.
|
||||||
|
|
||||||
|
Checks: platform (ROCm), device arch (gfx9), library existence,
|
||||||
|
and VLLM_ROCM_USE_AITER env variable.
|
||||||
|
"""
|
||||||
|
if current_platform.is_rocm() and IS_AITER_FOUND and envs.VLLM_ROCM_USE_AITER:
|
||||||
from vllm.platforms.rocm import on_gfx9
|
from vllm.platforms.rocm import on_gfx9
|
||||||
|
|
||||||
return on_gfx9()
|
return on_gfx9()
|
||||||
@@ -40,13 +45,11 @@ def is_aiter_found_and_supported() -> bool:
|
|||||||
|
|
||||||
def if_aiter_supported(func: Callable) -> Callable:
|
def if_aiter_supported(func: Callable) -> Callable:
|
||||||
"""Decorator that only executes the function if
|
"""Decorator that only executes the function if
|
||||||
ROCm AITER package is supported on gfx9 archs.
|
ROCm AITER package is supported and enabled on gfx9 archs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@functools.wraps(func)
|
@functools.wraps(func)
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs):
|
||||||
# checks the platform, device arch and aiter library existence.
|
|
||||||
|
|
||||||
if is_aiter_found_and_supported():
|
if is_aiter_found_and_supported():
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
@@ -63,6 +66,11 @@ if is_aiter_found_and_supported():
|
|||||||
from aiter import dtypes
|
from aiter import dtypes
|
||||||
|
|
||||||
AITER_FP8_DTYPE = dtypes.fp8
|
AITER_FP8_DTYPE = dtypes.fp8
|
||||||
|
else:
|
||||||
|
# Placeholder when AITER is disabled - prevents NameError during module load.
|
||||||
|
# Note: When AITER is disabled, ops are not registered, so fake implementations
|
||||||
|
# referencing this variable won't actually be called at runtime.
|
||||||
|
AITER_FP8_DTYPE = _FP8_DTYPE
|
||||||
|
|
||||||
|
|
||||||
def _rocm_aiter_fused_moe_impl(
|
def _rocm_aiter_fused_moe_impl(
|
||||||
|
|||||||
@@ -32,10 +32,11 @@ from vllm.v1.kv_cache_interface import AttentionSpec
|
|||||||
_PARTITION_SIZE_ROCM = 256
|
_PARTITION_SIZE_ROCM = 256
|
||||||
_CP_TOKENS_PER_ITER_ROCM = 32 * 1024
|
_CP_TOKENS_PER_ITER_ROCM = 32 * 1024
|
||||||
if current_platform.is_rocm():
|
if current_platform.is_rocm():
|
||||||
import aiter
|
|
||||||
|
|
||||||
from vllm.triton_utils import tl, triton
|
from vllm.triton_utils import tl, triton
|
||||||
|
|
||||||
|
if rocm_aiter_ops.is_enabled():
|
||||||
|
import aiter
|
||||||
|
|
||||||
def block_size(x, head_dim):
|
def block_size(x, head_dim):
|
||||||
return min(65536 // x.element_size(), triton.next_power_of_2(head_dim))
|
return min(65536 // x.element_size(), triton.next_power_of_2(head_dim))
|
||||||
|
|
||||||
|
|||||||
@@ -183,7 +183,9 @@ class SpecDecodeBaseProposer:
|
|||||||
RocmAttentionMetadata,
|
RocmAttentionMetadata,
|
||||||
]
|
]
|
||||||
# ROCM_AITER_FA is an optional backend
|
# ROCM_AITER_FA is an optional backend
|
||||||
if find_spec(
|
from vllm._aiter_ops import rocm_aiter_ops
|
||||||
|
|
||||||
|
if rocm_aiter_ops.is_enabled() and find_spec(
|
||||||
AttentionBackendEnum.ROCM_AITER_FA.get_path(include_classname=False)
|
AttentionBackendEnum.ROCM_AITER_FA.get_path(include_classname=False)
|
||||||
):
|
):
|
||||||
from vllm.v1.attention.backends.rocm_aiter_fa import (
|
from vllm.v1.attention.backends.rocm_aiter_fa import (
|
||||||
|
|||||||
Reference in New Issue
Block a user