[ROCm][AITER] Fix AITER import regression for explicit backend selection (#33749)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
@@ -5,10 +5,17 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm.v1.attention.backends.rocm_aiter_fa # noqa: F401
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import set_random_seed
|
||||
from vllm.v1.attention.backends.fa_utils import is_flash_attn_varlen_func_available
|
||||
|
||||
# Import AITER backend if on ROCm and aiter is available
|
||||
if current_platform.is_rocm():
|
||||
from vllm._aiter_ops import is_aiter_found_and_supported
|
||||
|
||||
if is_aiter_found_and_supported():
|
||||
import aiter
|
||||
|
||||
from vllm.v1.attention.backends.rocm_aiter_fa import cp_mha_gather_cache
|
||||
|
||||
NUM_HEADS = [(4, 4), (8, 2)]
|
||||
HEAD_SIZES = [128, 256]
|
||||
@@ -102,8 +109,11 @@ def test_varlen_with_paged_kv(
|
||||
num_blocks: int,
|
||||
q_dtype: torch.dtype | None,
|
||||
) -> None:
|
||||
if not is_flash_attn_varlen_func_available():
|
||||
pytest.skip("flash_attn_varlen_func required to run this test.")
|
||||
from vllm._aiter_ops import is_aiter_found_and_supported
|
||||
|
||||
if not is_aiter_found_and_supported():
|
||||
pytest.skip("aiter package required for this test.")
|
||||
|
||||
torch.set_default_device("cuda")
|
||||
set_random_seed(0)
|
||||
num_seqs = len(seq_lens)
|
||||
@@ -129,6 +139,8 @@ def test_varlen_with_paged_kv(
|
||||
cu_seq_lens = torch.tensor([0] + kv_lens, dtype=torch.int32).cumsum(
|
||||
dim=0, dtype=torch.int32
|
||||
)
|
||||
# Save kv_lens as list before converting to tensor
|
||||
kv_lens_list = kv_lens
|
||||
kv_lens = torch.tensor(kv_lens, dtype=torch.int32)
|
||||
|
||||
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
|
||||
@@ -141,33 +153,83 @@ def test_varlen_with_paged_kv(
|
||||
maybe_quantized_query = query
|
||||
maybe_quantized_key_cache = key_cache
|
||||
maybe_quantized_value_cache = value_cache
|
||||
k_descale = None
|
||||
v_descale = None
|
||||
k_scale_tensor = None
|
||||
v_scale_tensor = None
|
||||
dequant = False
|
||||
|
||||
if q_dtype is not None:
|
||||
# QKV are drawn from N(0, 1): no need for a fp8 scaling factor
|
||||
maybe_quantized_query = query.to(q_dtype)
|
||||
maybe_quantized_key_cache = key_cache.to(q_dtype)
|
||||
maybe_quantized_value_cache = value_cache.to(q_dtype)
|
||||
|
||||
dequant = True
|
||||
scale_shape = (num_seqs, num_kv_heads)
|
||||
k_descale = torch.ones(scale_shape, dtype=torch.float32)
|
||||
v_descale = torch.ones(scale_shape, dtype=torch.float32)
|
||||
|
||||
torch.ops.vllm.flash_attn_varlen_func(
|
||||
maybe_quantized_query,
|
||||
maybe_quantized_key_cache,
|
||||
maybe_quantized_value_cache,
|
||||
out=output,
|
||||
# For per-seq-per-head scales (matching AITER backend expectation)
|
||||
k_scale_tensor = torch.ones(scale_shape, dtype=torch.float32)
|
||||
v_scale_tensor = torch.ones(scale_shape, dtype=torch.float32)
|
||||
|
||||
# Prepare metadata for cp_mha_gather_cache
|
||||
# token_to_batch: maps each token to its batch index
|
||||
token_to_batch = torch.zeros(sum(kv_lens_list), dtype=torch.int32)
|
||||
seq_starts = torch.zeros(num_seqs, dtype=torch.int32)
|
||||
|
||||
token_idx = 0
|
||||
for batch_idx, kv_len in enumerate(kv_lens_list):
|
||||
token_to_batch[token_idx : token_idx + kv_len] = batch_idx
|
||||
seq_starts[batch_idx] = 0 # Assuming all sequences start at 0 in their blocks
|
||||
token_idx += kv_len
|
||||
|
||||
# Allocate buffers for gathered KV
|
||||
total_kv_tokens = sum(kv_lens_list)
|
||||
gathered_key = torch.empty(
|
||||
total_kv_tokens, num_kv_heads, head_size, dtype=maybe_quantized_key_cache.dtype
|
||||
)
|
||||
gathered_value = torch.empty(
|
||||
total_kv_tokens,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
dtype=maybe_quantized_value_cache.dtype,
|
||||
)
|
||||
|
||||
# Gather paged KV cache into contiguous tensors using triton kernel
|
||||
cp_mha_gather_cache(
|
||||
key_cache=maybe_quantized_key_cache,
|
||||
value_cache=maybe_quantized_value_cache,
|
||||
key=gathered_key,
|
||||
value=gathered_value,
|
||||
block_tables=block_tables,
|
||||
k_scales=k_scale_tensor
|
||||
if k_scale_tensor is not None
|
||||
else torch.ones(1, dtype=torch.float32),
|
||||
v_scales=v_scale_tensor
|
||||
if v_scale_tensor is not None
|
||||
else torch.ones(1, dtype=torch.float32),
|
||||
cu_seqlens_kv=cu_seq_lens,
|
||||
token_to_batch=token_to_batch,
|
||||
seq_starts=seq_starts,
|
||||
dequant=dequant,
|
||||
kv_cache_layout="NHD",
|
||||
total_tokens=total_kv_tokens,
|
||||
)
|
||||
|
||||
# Call aiter flash attention with gathered KV
|
||||
aiter.flash_attn_varlen_func(
|
||||
q=maybe_quantized_query,
|
||||
k=gathered_key,
|
||||
v=gathered_value,
|
||||
cu_seqlens_q=cu_query_lens,
|
||||
cu_seqlens_k=cu_seq_lens,
|
||||
max_seqlen_q=max_query_len,
|
||||
max_seqlen_k=max_kv_len,
|
||||
min_seqlen_q=1,
|
||||
dropout_p=0.0,
|
||||
softmax_scale=scale,
|
||||
alibi_slopes=None,
|
||||
causal=True,
|
||||
window_size=window_size,
|
||||
block_table=block_tables,
|
||||
cu_seqlens_k=cu_seq_lens,
|
||||
k_scale=k_descale,
|
||||
v_scale=v_descale,
|
||||
alibi_slopes=None,
|
||||
return_lse=False,
|
||||
out=output,
|
||||
)
|
||||
|
||||
ref_output = ref_paged_attn(
|
||||
@@ -175,7 +237,7 @@ def test_varlen_with_paged_kv(
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
query_lens=query_lens,
|
||||
kv_lens=kv_lens,
|
||||
kv_lens=kv_lens_list,
|
||||
block_tables=block_tables,
|
||||
scale=scale,
|
||||
sliding_window=sliding_window,
|
||||
@@ -189,3 +251,8 @@ def test_varlen_with_paged_kv(
|
||||
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol),
|
||||
f"{torch.max(torch.abs(output - ref_output))}",
|
||||
)
|
||||
|
||||
# Log diff stats for tracking changes
|
||||
print(f"Max abs diff: {torch.max(torch.abs(output - ref_output))}")
|
||||
print(f"Mean diff: {torch.mean(torch.abs(output - ref_output))}")
|
||||
print(f"Min diff: {torch.std(torch.abs(output - ref_output))}")
|
||||
|
||||
@@ -14,7 +14,10 @@ from vllm.v1.attention.ops.rocm_aiter_mla_sparse import (
|
||||
rocm_aiter_sparse_attn_indexer_fake,
|
||||
)
|
||||
|
||||
_FP8_DTYPE = current_platform.fp8_dtype()
|
||||
# fp8_dtype is not cached.
|
||||
# on ROCm the fp8_dtype always calls is_fp8_fnuz
|
||||
# which is a host op, so we cache it once here.
|
||||
FP8_DTYPE = current_platform.fp8_dtype()
|
||||
|
||||
|
||||
def is_aiter_found() -> bool:
|
||||
@@ -31,12 +34,22 @@ IS_AITER_FOUND = is_aiter_found()
|
||||
|
||||
|
||||
def is_aiter_found_and_supported() -> bool:
|
||||
"""Check if AITER is available AND enabled via environment variable.
|
||||
"""Check if AITER library is available and platform supports it.
|
||||
|
||||
Checks: platform (ROCm), device arch (gfx9), library existence,
|
||||
and VLLM_ROCM_USE_AITER env variable.
|
||||
Checks: platform (ROCm), device arch (gfx9), and library existence.
|
||||
Does NOT check environment variables - that's handled by rocm_aiter_ops.is_enabled().
|
||||
|
||||
This function determines if aiter CAN be used, not if it SHOULD be used.
|
||||
|
||||
Separation of concerns:
|
||||
- This function: Can aiter work on this system? (platform + library availability)
|
||||
- rocm_aiter_ops.is_enabled(): Should aiter be used by default? (adds env var check)
|
||||
- Backend selection: Can explicitly request aiter regardless of env var
|
||||
|
||||
This allows explicit backend selection via attention_config to work even when
|
||||
VLLM_ROCM_USE_AITER=0, while preventing unwanted JIT warnings for auto-discovery.
|
||||
"""
|
||||
if current_platform.is_rocm() and IS_AITER_FOUND and envs.VLLM_ROCM_USE_AITER:
|
||||
if current_platform.is_rocm() and IS_AITER_FOUND:
|
||||
from vllm.platforms.rocm import on_gfx9
|
||||
|
||||
return on_gfx9()
|
||||
@@ -58,21 +71,6 @@ def if_aiter_supported(func: Callable) -> Callable:
|
||||
return wrapper
|
||||
|
||||
|
||||
# Can't use dtypes.fp8 directly inside an op
|
||||
# because it returns wrong result on gfx942.
|
||||
# This is a workaround to get the correct FP8 dtype.
|
||||
# This might because that the get_gfx() is wrapped as a custom op.
|
||||
if is_aiter_found_and_supported():
|
||||
from aiter import dtypes
|
||||
|
||||
AITER_FP8_DTYPE = dtypes.fp8
|
||||
else:
|
||||
# Placeholder when AITER is disabled - prevents NameError during module load.
|
||||
# Note: When AITER is disabled, ops are not registered, so fake implementations
|
||||
# referencing this variable won't actually be called at runtime.
|
||||
AITER_FP8_DTYPE = _FP8_DTYPE
|
||||
|
||||
|
||||
def _rocm_aiter_fused_moe_impl(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
@@ -539,7 +537,7 @@ def _rocm_aiter_rmsnorm_fused_add_dynamic_quant_impl(
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
import aiter as rocm_aiter
|
||||
|
||||
assert quant_dtype in [torch.int8, _FP8_DTYPE]
|
||||
assert quant_dtype in [torch.int8, FP8_DTYPE]
|
||||
|
||||
y_scale = torch.empty(x.shape[0], 1, dtype=torch.float32, device=x.device)
|
||||
out = torch.empty(x.shape, dtype=quant_dtype, device=x.device)
|
||||
@@ -581,7 +579,7 @@ def _rocm_aiter_rmsnorm_fused_dynamic_quant_impl(
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
import aiter as rocm_aiter
|
||||
|
||||
assert quant_dtype in [torch.int8, _FP8_DTYPE]
|
||||
assert quant_dtype in [torch.int8, FP8_DTYPE]
|
||||
|
||||
y_scale = torch.empty(x.shape[0], 1, dtype=torch.float32, device=x.device)
|
||||
out = torch.empty(x.shape, dtype=quant_dtype, device=x.device)
|
||||
@@ -630,10 +628,10 @@ def _rocm_aiter_per_token_quant_impl(
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
from aiter.ops.quant import dynamic_per_token_scaled_quant
|
||||
|
||||
assert quant_dtype in [torch.int8, _FP8_DTYPE]
|
||||
assert quant_dtype in [torch.int8, FP8_DTYPE]
|
||||
|
||||
out_shape = x.shape
|
||||
out = torch.empty(x.shape, dtype=_FP8_DTYPE, device=x.device)
|
||||
out = torch.empty(x.shape, dtype=FP8_DTYPE, device=x.device)
|
||||
if scale is None:
|
||||
scale = torch.empty((*out_shape[:-1], 1), dtype=torch.float32, device=x.device)
|
||||
dynamic_per_token_scaled_quant(
|
||||
@@ -653,7 +651,7 @@ def _rocm_aiter_per_token_quant_fake(
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
out_shape = x.shape
|
||||
return (
|
||||
torch.empty(x.shape, dtype=_FP8_DTYPE, device=x.device),
|
||||
torch.empty(x.shape, dtype=FP8_DTYPE, device=x.device),
|
||||
torch.empty((*out_shape[:-1], 1), dtype=torch.float32, device=x.device),
|
||||
)
|
||||
|
||||
@@ -675,7 +673,7 @@ def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl(
|
||||
None,
|
||||
None,
|
||||
group_size=group_size,
|
||||
dtype_quant=AITER_FP8_DTYPE,
|
||||
dtype_quant=FP8_DTYPE,
|
||||
res1=residual,
|
||||
)
|
||||
return (
|
||||
@@ -695,7 +693,7 @@ def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake(
|
||||
M, N = x.shape
|
||||
scale_shape = (M, (N + group_size - 1) // group_size)
|
||||
return (
|
||||
torch.empty_like(x, dtype=AITER_FP8_DTYPE, device=x.device),
|
||||
torch.empty_like(x, dtype=FP8_DTYPE, device=x.device),
|
||||
torch.empty_like(residual, device=residual.device),
|
||||
torch.empty(scale_shape, dtype=torch.float32, device=x.device),
|
||||
)
|
||||
@@ -717,7 +715,7 @@ def _rocm_aiter_rmsnorm_fp8_group_quant_impl(
|
||||
None,
|
||||
None,
|
||||
group_size=group_size,
|
||||
dtype_quant=AITER_FP8_DTYPE,
|
||||
dtype_quant=FP8_DTYPE,
|
||||
res1=None,
|
||||
)
|
||||
return (x_quant, x_quant_scales)
|
||||
@@ -732,7 +730,7 @@ def _rocm_aiter_rmsnorm_fp8_group_quant_fake(
|
||||
M, N = x.shape
|
||||
scale_shape = (M, (N + group_size - 1) // group_size)
|
||||
return (
|
||||
torch.empty_like(x, dtype=AITER_FP8_DTYPE, device=x.device),
|
||||
torch.empty_like(x, dtype=FP8_DTYPE, device=x.device),
|
||||
torch.empty(scale_shape, dtype=torch.float32, device=x.device),
|
||||
)
|
||||
|
||||
@@ -745,7 +743,7 @@ def _rocm_aiter_group_fp8_quant_impl(
|
||||
from aiter import QuantType, get_hip_quant
|
||||
|
||||
aiter_per1x128_quant = get_hip_quant(QuantType.per_1x128)
|
||||
return aiter_per1x128_quant(x.contiguous(), quant_dtype=AITER_FP8_DTYPE)
|
||||
return aiter_per1x128_quant(x.contiguous(), quant_dtype=FP8_DTYPE)
|
||||
|
||||
|
||||
def _rocm_aiter_group_fp8_quant_fake(
|
||||
@@ -753,7 +751,7 @@ def _rocm_aiter_group_fp8_quant_fake(
|
||||
group_size: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
M, N = x.shape
|
||||
x_fp8 = torch.empty((M, N), dtype=AITER_FP8_DTYPE, device=x.device)
|
||||
x_fp8 = torch.empty((M, N), dtype=FP8_DTYPE, device=x.device)
|
||||
out_bs = torch.empty(
|
||||
(
|
||||
M,
|
||||
@@ -775,7 +773,7 @@ def _rocm_aiter_act_mul_and_fp8_group_quant_impl(
|
||||
x,
|
||||
activation="silu",
|
||||
group_size=group_size,
|
||||
dtype_quant=AITER_FP8_DTYPE,
|
||||
dtype_quant=FP8_DTYPE,
|
||||
)
|
||||
|
||||
|
||||
@@ -786,7 +784,7 @@ def _rocm_aiter_act_mul_and_fp8_group_quant_fake(
|
||||
M, N = x.shape
|
||||
assert N % 2 == 0
|
||||
N_half = N // 2
|
||||
x_fp8 = torch.empty((M, N_half), dtype=AITER_FP8_DTYPE, device=x.device)
|
||||
x_fp8 = torch.empty((M, N_half), dtype=FP8_DTYPE, device=x.device)
|
||||
out_bs = torch.empty(
|
||||
(
|
||||
M,
|
||||
@@ -986,7 +984,7 @@ class rocm_aiter_ops:
|
||||
@classmethod
|
||||
@if_aiter_supported
|
||||
def is_shuffle_kv_cache_enabled(cls) -> bool:
|
||||
return cls._AITER_ENABLED and cls._SHUFFLE_KV_CACHE_ENABLED
|
||||
return cls._SHUFFLE_KV_CACHE_ENABLED
|
||||
|
||||
@classmethod
|
||||
@if_aiter_supported
|
||||
@@ -1654,5 +1652,87 @@ class rocm_aiter_ops:
|
||||
|
||||
return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors)
|
||||
|
||||
@staticmethod
|
||||
def flash_attn_varlen_func(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
cu_seqlens_q: torch.Tensor,
|
||||
cu_seqlens_k: torch.Tensor,
|
||||
max_seqlen_q: int,
|
||||
max_seqlen_k: int,
|
||||
min_seqlen_q: int | None = None,
|
||||
dropout_p: float = 0.0,
|
||||
softmax_scale: float | None = None,
|
||||
causal: bool = False,
|
||||
window_size: tuple[int, int] | None = None,
|
||||
alibi_slopes: torch.Tensor | None = None,
|
||||
return_lse: bool = False,
|
||||
out: torch.Tensor | None = None,
|
||||
):
|
||||
"""
|
||||
Flash attention with variable length sequences.
|
||||
|
||||
This function is NOT wrapped with @is_aiter_supported decorator
|
||||
to allow explicit backend selection via attention_config to work
|
||||
even when VLLM_ROCM_USE_AITER=0.
|
||||
|
||||
Note: This performs lazy import of aiter.flash_attn_varlen_func
|
||||
"""
|
||||
from aiter import flash_attn_varlen_func
|
||||
|
||||
return flash_attn_varlen_func(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
max_seqlen_k=max_seqlen_k,
|
||||
min_seqlen_q=min_seqlen_q,
|
||||
dropout_p=dropout_p,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=causal,
|
||||
window_size=window_size,
|
||||
alibi_slopes=alibi_slopes,
|
||||
return_lse=return_lse,
|
||||
out=out,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def pa_fwd_asm(
|
||||
Q: torch.Tensor,
|
||||
K: torch.Tensor,
|
||||
V: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
context_lens: torch.Tensor,
|
||||
block_tables_stride0: int,
|
||||
K_QScale: torch.Tensor,
|
||||
V_QScale: torch.Tensor,
|
||||
out_: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Paged attention forward pass using assembly kernel.
|
||||
|
||||
This function is NOT wrapped with @is_aiter_supported decorator
|
||||
to allow explicit backend selection via attention_config to work
|
||||
even when VLLM_ROCM_USE_AITER=0.
|
||||
|
||||
Note: This performs lazy import of aiter.pa_fwd_asm
|
||||
"""
|
||||
from aiter import pa_fwd_asm
|
||||
|
||||
return pa_fwd_asm(
|
||||
Q=Q,
|
||||
K=K,
|
||||
V=V,
|
||||
block_tables=block_tables,
|
||||
context_lens=context_lens,
|
||||
block_tables_stride0=block_tables_stride0,
|
||||
K_QScale=K_QScale,
|
||||
V_QScale=V_QScale,
|
||||
out_=out_,
|
||||
)
|
||||
|
||||
|
||||
rocm_aiter_ops.register_ops_once()
|
||||
|
||||
@@ -8,6 +8,12 @@ from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Track whether upstream flash-attn is available on ROCm.
|
||||
# Set during module initialization and never modified afterwards.
|
||||
# This module-level flag avoids repeated import attempts and ensures
|
||||
# consistent behavior (similar to IS_AITER_FOUND in _aiter_ops.py).
|
||||
_ROCM_FLASH_ATTN_AVAILABLE = False
|
||||
|
||||
if current_platform.is_cuda():
|
||||
from vllm._custom_ops import reshape_and_cache_flash
|
||||
from vllm.vllm_flash_attn import ( # type: ignore[attr-defined]
|
||||
@@ -26,6 +32,9 @@ elif current_platform.is_xpu():
|
||||
elif current_platform.is_rocm():
|
||||
try:
|
||||
from flash_attn import flash_attn_varlen_func # type: ignore[no-redef]
|
||||
|
||||
# Mark that upstream flash-attn is available on ROCm
|
||||
_ROCM_FLASH_ATTN_AVAILABLE = True
|
||||
except ImportError:
|
||||
|
||||
def flash_attn_varlen_func(*args: Any, **kwargs: Any) -> Any: # type: ignore[no-redef,misc]
|
||||
@@ -34,6 +43,15 @@ elif current_platform.is_rocm():
|
||||
"to be installed. Please install flash-attn first."
|
||||
)
|
||||
|
||||
# ROCm doesn't use scheduler metadata (FA3 feature), provide stub
|
||||
def get_scheduler_metadata(*args: Any, **kwargs: Any) -> None: # type: ignore[misc]
|
||||
return None
|
||||
|
||||
# ROCm uses the C++ custom op for reshape_and_cache
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
reshape_and_cache_flash = ops.reshape_and_cache_flash
|
||||
|
||||
|
||||
def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
|
||||
# import here to avoid circular dependencies
|
||||
@@ -128,4 +146,30 @@ def flash_attn_supports_mla():
|
||||
|
||||
|
||||
def is_flash_attn_varlen_func_available() -> bool:
|
||||
return current_platform.is_cuda() or current_platform.is_xpu()
|
||||
"""Check if flash_attn_varlen_func is available.
|
||||
|
||||
This function determines whether the flash_attn_varlen_func imported at module
|
||||
level is a working implementation or a stub.
|
||||
|
||||
Platform-specific sources:
|
||||
- CUDA: vllm.vllm_flash_attn.flash_attn_varlen_func
|
||||
- XPU: ipex_ops.flash_attn_varlen_func
|
||||
- ROCm: upstream flash_attn.flash_attn_varlen_func (if available)
|
||||
|
||||
Note: This is separate from the AITER flash attention backend (rocm_aiter_fa.py)
|
||||
which uses rocm_aiter_ops.flash_attn_varlen_func. The condition to use AITER is
|
||||
handled separately via _aiter_ops.is_aiter_found_and_supported().
|
||||
|
||||
Returns:
|
||||
bool: True if a working flash_attn_varlen_func implementation is available.
|
||||
"""
|
||||
if current_platform.is_cuda() or current_platform.is_xpu():
|
||||
# CUDA and XPU always have flash_attn_varlen_func available
|
||||
return True
|
||||
|
||||
if current_platform.is_rocm():
|
||||
# Use the flag set during module import to check if
|
||||
# upstream flash-attn was successfully imported
|
||||
return _ROCM_FLASH_ATTN_AVAILABLE
|
||||
|
||||
return False
|
||||
|
||||
@@ -34,9 +34,6 @@ _CP_TOKENS_PER_ITER_ROCM = 32 * 1024
|
||||
if current_platform.is_rocm():
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
if rocm_aiter_ops.is_enabled():
|
||||
import aiter
|
||||
|
||||
def block_size(x, head_dim):
|
||||
return min(65536 // x.element_size(), triton.next_power_of_2(head_dim))
|
||||
|
||||
@@ -798,7 +795,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
||||
total_tokens=swa_total_tokens,
|
||||
)
|
||||
|
||||
aiter.flash_attn_varlen_func(
|
||||
rocm_aiter_ops.flash_attn_varlen_func(
|
||||
q=query,
|
||||
k=key_fetched,
|
||||
v=value_fetched,
|
||||
@@ -848,7 +845,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
||||
v_scale,
|
||||
)
|
||||
return
|
||||
out, lse = aiter.flash_attn_varlen_func(
|
||||
out, lse = rocm_aiter_ops.flash_attn_varlen_func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
@@ -895,7 +892,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
||||
total_tokens=total_token_per_batch[chunk_idx],
|
||||
)
|
||||
|
||||
suf_out, suf_lse = aiter.flash_attn_varlen_func(
|
||||
suf_out, suf_lse = rocm_aiter_ops.flash_attn_varlen_func(
|
||||
q=query,
|
||||
k=key_fetched,
|
||||
v=value_fetched,
|
||||
@@ -1053,7 +1050,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
||||
prefill_key = key[num_decode_tokens + num_extend_tokens :]
|
||||
prefill_value = value[num_decode_tokens + num_extend_tokens :]
|
||||
|
||||
aiter.flash_attn_varlen_func(
|
||||
rocm_aiter_ops.flash_attn_varlen_func(
|
||||
q=prefill_query,
|
||||
k=prefill_key,
|
||||
v=prefill_value,
|
||||
@@ -1159,7 +1156,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
||||
)
|
||||
new_key_cache = key_cache.view_as(k_cache_template)
|
||||
new_value_cache = value_cache.view_as(v_cache_template)
|
||||
aiter.pa_fwd_asm(
|
||||
rocm_aiter_ops.pa_fwd_asm(
|
||||
Q=query[:num_decode_tokens],
|
||||
K=new_key_cache,
|
||||
V=new_value_cache,
|
||||
@@ -1188,6 +1185,10 @@ class AiterFlashAttentionImpl(AttentionImpl):
|
||||
device=output.device,
|
||||
)
|
||||
|
||||
# import so that aiter register the op to the namespace of
|
||||
# torch.ops.aiter
|
||||
import aiter # noqa: F401
|
||||
|
||||
torch.ops.aiter.paged_attention_v1(
|
||||
output[:num_decode_tokens],
|
||||
workspace_buffer,
|
||||
|
||||
@@ -222,9 +222,13 @@ class SpecDecodeBaseProposer:
|
||||
RocmAttentionMetadata,
|
||||
]
|
||||
# ROCM_AITER_FA is an optional backend
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
|
||||
if rocm_aiter_ops.is_enabled() and find_spec(
|
||||
# We check is_enabled() here to avoid importing the backend module during
|
||||
# auto-discovery when VLLM_ROCM_USE_AITER=0, which would trigger aiter
|
||||
# import and JIT compilation warnings. Explicit backend selection via
|
||||
# attention_config still works because the backend module is loaded
|
||||
# directly when selected, not through this auto-discovery path.
|
||||
# Check if backend module exists to allow explicit selection
|
||||
if find_spec(
|
||||
AttentionBackendEnum.ROCM_AITER_FA.get_path(include_classname=False)
|
||||
):
|
||||
from vllm.v1.attention.backends.rocm_aiter_fa import (
|
||||
|
||||
Reference in New Issue
Block a user