[ROCm][AITER] Fix AITER import regression for explicit backend selection (#33749)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
This commit is contained in:
Andreas Karatzas
2026-02-06 09:08:16 -06:00
committed by GitHub
parent 1fb0495a72
commit 350ca72c04
5 changed files with 262 additions and 66 deletions

View File

@@ -5,10 +5,17 @@
import pytest
import torch
import vllm.v1.attention.backends.rocm_aiter_fa # noqa: F401
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_random_seed
from vllm.v1.attention.backends.fa_utils import is_flash_attn_varlen_func_available
# Import AITER backend if on ROCm and aiter is available
if current_platform.is_rocm():
from vllm._aiter_ops import is_aiter_found_and_supported
if is_aiter_found_and_supported():
import aiter
from vllm.v1.attention.backends.rocm_aiter_fa import cp_mha_gather_cache
NUM_HEADS = [(4, 4), (8, 2)]
HEAD_SIZES = [128, 256]
@@ -102,8 +109,11 @@ def test_varlen_with_paged_kv(
num_blocks: int,
q_dtype: torch.dtype | None,
) -> None:
if not is_flash_attn_varlen_func_available():
pytest.skip("flash_attn_varlen_func required to run this test.")
from vllm._aiter_ops import is_aiter_found_and_supported
if not is_aiter_found_and_supported():
pytest.skip("aiter package required for this test.")
torch.set_default_device("cuda")
set_random_seed(0)
num_seqs = len(seq_lens)
@@ -129,6 +139,8 @@ def test_varlen_with_paged_kv(
cu_seq_lens = torch.tensor([0] + kv_lens, dtype=torch.int32).cumsum(
dim=0, dtype=torch.int32
)
# Save kv_lens as list before converting to tensor
kv_lens_list = kv_lens
kv_lens = torch.tensor(kv_lens, dtype=torch.int32)
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
@@ -141,33 +153,83 @@ def test_varlen_with_paged_kv(
maybe_quantized_query = query
maybe_quantized_key_cache = key_cache
maybe_quantized_value_cache = value_cache
k_descale = None
v_descale = None
k_scale_tensor = None
v_scale_tensor = None
dequant = False
if q_dtype is not None:
# QKV are drawn from N(0, 1): no need for a fp8 scaling factor
maybe_quantized_query = query.to(q_dtype)
maybe_quantized_key_cache = key_cache.to(q_dtype)
maybe_quantized_value_cache = value_cache.to(q_dtype)
dequant = True
scale_shape = (num_seqs, num_kv_heads)
k_descale = torch.ones(scale_shape, dtype=torch.float32)
v_descale = torch.ones(scale_shape, dtype=torch.float32)
torch.ops.vllm.flash_attn_varlen_func(
maybe_quantized_query,
maybe_quantized_key_cache,
maybe_quantized_value_cache,
out=output,
# For per-seq-per-head scales (matching AITER backend expectation)
k_scale_tensor = torch.ones(scale_shape, dtype=torch.float32)
v_scale_tensor = torch.ones(scale_shape, dtype=torch.float32)
# Prepare metadata for cp_mha_gather_cache
# token_to_batch: maps each token to its batch index
token_to_batch = torch.zeros(sum(kv_lens_list), dtype=torch.int32)
seq_starts = torch.zeros(num_seqs, dtype=torch.int32)
token_idx = 0
for batch_idx, kv_len in enumerate(kv_lens_list):
token_to_batch[token_idx : token_idx + kv_len] = batch_idx
seq_starts[batch_idx] = 0 # Assuming all sequences start at 0 in their blocks
token_idx += kv_len
# Allocate buffers for gathered KV
total_kv_tokens = sum(kv_lens_list)
gathered_key = torch.empty(
total_kv_tokens, num_kv_heads, head_size, dtype=maybe_quantized_key_cache.dtype
)
gathered_value = torch.empty(
total_kv_tokens,
num_kv_heads,
head_size,
dtype=maybe_quantized_value_cache.dtype,
)
# Gather paged KV cache into contiguous tensors using triton kernel
cp_mha_gather_cache(
key_cache=maybe_quantized_key_cache,
value_cache=maybe_quantized_value_cache,
key=gathered_key,
value=gathered_value,
block_tables=block_tables,
k_scales=k_scale_tensor
if k_scale_tensor is not None
else torch.ones(1, dtype=torch.float32),
v_scales=v_scale_tensor
if v_scale_tensor is not None
else torch.ones(1, dtype=torch.float32),
cu_seqlens_kv=cu_seq_lens,
token_to_batch=token_to_batch,
seq_starts=seq_starts,
dequant=dequant,
kv_cache_layout="NHD",
total_tokens=total_kv_tokens,
)
# Call aiter flash attention with gathered KV
aiter.flash_attn_varlen_func(
q=maybe_quantized_query,
k=gathered_key,
v=gathered_value,
cu_seqlens_q=cu_query_lens,
cu_seqlens_k=cu_seq_lens,
max_seqlen_q=max_query_len,
max_seqlen_k=max_kv_len,
min_seqlen_q=1,
dropout_p=0.0,
softmax_scale=scale,
alibi_slopes=None,
causal=True,
window_size=window_size,
block_table=block_tables,
cu_seqlens_k=cu_seq_lens,
k_scale=k_descale,
v_scale=v_descale,
alibi_slopes=None,
return_lse=False,
out=output,
)
ref_output = ref_paged_attn(
@@ -175,7 +237,7 @@ def test_varlen_with_paged_kv(
key_cache=key_cache,
value_cache=value_cache,
query_lens=query_lens,
kv_lens=kv_lens,
kv_lens=kv_lens_list,
block_tables=block_tables,
scale=scale,
sliding_window=sliding_window,
@@ -189,3 +251,8 @@ def test_varlen_with_paged_kv(
torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol),
f"{torch.max(torch.abs(output - ref_output))}",
)
# Log diff stats for tracking changes
print(f"Max abs diff: {torch.max(torch.abs(output - ref_output))}")
print(f"Mean diff: {torch.mean(torch.abs(output - ref_output))}")
print(f"Std diff: {torch.std(torch.abs(output - ref_output))}")

View File

@@ -14,7 +14,10 @@ from vllm.v1.attention.ops.rocm_aiter_mla_sparse import (
rocm_aiter_sparse_attn_indexer_fake,
)
_FP8_DTYPE = current_platform.fp8_dtype()
# fp8_dtype is not cached.
# on ROCm the fp8_dtype always calls is_fp8_fnuz
# which is a host op, so we cache it once here.
FP8_DTYPE = current_platform.fp8_dtype()
def is_aiter_found() -> bool:
@@ -31,12 +34,22 @@ IS_AITER_FOUND = is_aiter_found()
def is_aiter_found_and_supported() -> bool:
"""Check if AITER is available AND enabled via environment variable.
"""Check if AITER library is available and platform supports it.
Checks: platform (ROCm), device arch (gfx9), library existence,
and VLLM_ROCM_USE_AITER env variable.
Checks: platform (ROCm), device arch (gfx9), and library existence.
Does NOT check environment variables - that's handled by rocm_aiter_ops.is_enabled().
This function determines if aiter CAN be used, not if it SHOULD be used.
Separation of concerns:
- This function: Can aiter work on this system? (platform + library availability)
- rocm_aiter_ops.is_enabled(): Should aiter be used by default? (adds env var check)
- Backend selection: Can explicitly request aiter regardless of env var
This allows explicit backend selection via attention_config to work even when
VLLM_ROCM_USE_AITER=0, while preventing unwanted JIT warnings for auto-discovery.
"""
if current_platform.is_rocm() and IS_AITER_FOUND and envs.VLLM_ROCM_USE_AITER:
if current_platform.is_rocm() and IS_AITER_FOUND:
from vllm.platforms.rocm import on_gfx9
return on_gfx9()
@@ -58,21 +71,6 @@ def if_aiter_supported(func: Callable) -> Callable:
return wrapper
# Can't use dtypes.fp8 directly inside an op
# because it returns wrong result on gfx942.
# This is a workaround to get the correct FP8 dtype.
# This might be because get_gfx() is wrapped as a custom op.
if is_aiter_found_and_supported():
from aiter import dtypes
AITER_FP8_DTYPE = dtypes.fp8
else:
# Placeholder when AITER is disabled - prevents NameError during module load.
# Note: When AITER is disabled, ops are not registered, so fake implementations
# referencing this variable won't actually be called at runtime.
AITER_FP8_DTYPE = _FP8_DTYPE
def _rocm_aiter_fused_moe_impl(
hidden_states: torch.Tensor,
w1: torch.Tensor,
@@ -539,7 +537,7 @@ def _rocm_aiter_rmsnorm_fused_add_dynamic_quant_impl(
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
import aiter as rocm_aiter
assert quant_dtype in [torch.int8, _FP8_DTYPE]
assert quant_dtype in [torch.int8, FP8_DTYPE]
y_scale = torch.empty(x.shape[0], 1, dtype=torch.float32, device=x.device)
out = torch.empty(x.shape, dtype=quant_dtype, device=x.device)
@@ -581,7 +579,7 @@ def _rocm_aiter_rmsnorm_fused_dynamic_quant_impl(
) -> tuple[torch.Tensor, torch.Tensor]:
import aiter as rocm_aiter
assert quant_dtype in [torch.int8, _FP8_DTYPE]
assert quant_dtype in [torch.int8, FP8_DTYPE]
y_scale = torch.empty(x.shape[0], 1, dtype=torch.float32, device=x.device)
out = torch.empty(x.shape, dtype=quant_dtype, device=x.device)
@@ -630,10 +628,10 @@ def _rocm_aiter_per_token_quant_impl(
) -> tuple[torch.Tensor, torch.Tensor]:
from aiter.ops.quant import dynamic_per_token_scaled_quant
assert quant_dtype in [torch.int8, _FP8_DTYPE]
assert quant_dtype in [torch.int8, FP8_DTYPE]
out_shape = x.shape
out = torch.empty(x.shape, dtype=_FP8_DTYPE, device=x.device)
out = torch.empty(x.shape, dtype=FP8_DTYPE, device=x.device)
if scale is None:
scale = torch.empty((*out_shape[:-1], 1), dtype=torch.float32, device=x.device)
dynamic_per_token_scaled_quant(
@@ -653,7 +651,7 @@ def _rocm_aiter_per_token_quant_fake(
) -> tuple[torch.Tensor, torch.Tensor]:
out_shape = x.shape
return (
torch.empty(x.shape, dtype=_FP8_DTYPE, device=x.device),
torch.empty(x.shape, dtype=FP8_DTYPE, device=x.device),
torch.empty((*out_shape[:-1], 1), dtype=torch.float32, device=x.device),
)
@@ -675,7 +673,7 @@ def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_impl(
None,
None,
group_size=group_size,
dtype_quant=AITER_FP8_DTYPE,
dtype_quant=FP8_DTYPE,
res1=residual,
)
return (
@@ -695,7 +693,7 @@ def _rocm_aiter_rmsnorm_with_add_fp8_group_quant_fake(
M, N = x.shape
scale_shape = (M, (N + group_size - 1) // group_size)
return (
torch.empty_like(x, dtype=AITER_FP8_DTYPE, device=x.device),
torch.empty_like(x, dtype=FP8_DTYPE, device=x.device),
torch.empty_like(residual, device=residual.device),
torch.empty(scale_shape, dtype=torch.float32, device=x.device),
)
@@ -717,7 +715,7 @@ def _rocm_aiter_rmsnorm_fp8_group_quant_impl(
None,
None,
group_size=group_size,
dtype_quant=AITER_FP8_DTYPE,
dtype_quant=FP8_DTYPE,
res1=None,
)
return (x_quant, x_quant_scales)
@@ -732,7 +730,7 @@ def _rocm_aiter_rmsnorm_fp8_group_quant_fake(
M, N = x.shape
scale_shape = (M, (N + group_size - 1) // group_size)
return (
torch.empty_like(x, dtype=AITER_FP8_DTYPE, device=x.device),
torch.empty_like(x, dtype=FP8_DTYPE, device=x.device),
torch.empty(scale_shape, dtype=torch.float32, device=x.device),
)
@@ -745,7 +743,7 @@ def _rocm_aiter_group_fp8_quant_impl(
from aiter import QuantType, get_hip_quant
aiter_per1x128_quant = get_hip_quant(QuantType.per_1x128)
return aiter_per1x128_quant(x.contiguous(), quant_dtype=AITER_FP8_DTYPE)
return aiter_per1x128_quant(x.contiguous(), quant_dtype=FP8_DTYPE)
def _rocm_aiter_group_fp8_quant_fake(
@@ -753,7 +751,7 @@ def _rocm_aiter_group_fp8_quant_fake(
group_size: int,
) -> tuple[torch.Tensor, torch.Tensor]:
M, N = x.shape
x_fp8 = torch.empty((M, N), dtype=AITER_FP8_DTYPE, device=x.device)
x_fp8 = torch.empty((M, N), dtype=FP8_DTYPE, device=x.device)
out_bs = torch.empty(
(
M,
@@ -775,7 +773,7 @@ def _rocm_aiter_act_mul_and_fp8_group_quant_impl(
x,
activation="silu",
group_size=group_size,
dtype_quant=AITER_FP8_DTYPE,
dtype_quant=FP8_DTYPE,
)
@@ -786,7 +784,7 @@ def _rocm_aiter_act_mul_and_fp8_group_quant_fake(
M, N = x.shape
assert N % 2 == 0
N_half = N // 2
x_fp8 = torch.empty((M, N_half), dtype=AITER_FP8_DTYPE, device=x.device)
x_fp8 = torch.empty((M, N_half), dtype=FP8_DTYPE, device=x.device)
out_bs = torch.empty(
(
M,
@@ -986,7 +984,7 @@ class rocm_aiter_ops:
@classmethod
@if_aiter_supported
def is_shuffle_kv_cache_enabled(cls) -> bool:
return cls._AITER_ENABLED and cls._SHUFFLE_KV_CACHE_ENABLED
return cls._SHUFFLE_KV_CACHE_ENABLED
@classmethod
@if_aiter_supported
@@ -1654,5 +1652,87 @@ class rocm_aiter_ops:
return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors)
@staticmethod
def flash_attn_varlen_func(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_q: torch.Tensor,
    cu_seqlens_k: torch.Tensor,
    max_seqlen_q: int,
    max_seqlen_k: int,
    min_seqlen_q: int | None = None,
    dropout_p: float = 0.0,
    softmax_scale: float | None = None,
    causal: bool = False,
    window_size: tuple[int, int] | None = None,
    alibi_slopes: torch.Tensor | None = None,
    return_lse: bool = False,
    out: torch.Tensor | None = None,
):
    """Variable-length flash attention, delegated to the AITER library.

    Deliberately NOT decorated with @if_aiter_supported so that an
    explicitly selected AITER backend (via attention_config) keeps
    working even when VLLM_ROCM_USE_AITER=0.

    Note: ``aiter.flash_attn_varlen_func`` is imported lazily, only
    when this wrapper is actually invoked.
    """
    # Lazy import keeps module load free of aiter side effects (e.g.
    # JIT compilation warnings) until the op is really needed.
    from aiter import flash_attn_varlen_func as _aiter_varlen_attn

    attn_kwargs = {
        "q": q,
        "k": k,
        "v": v,
        "cu_seqlens_q": cu_seqlens_q,
        "cu_seqlens_k": cu_seqlens_k,
        "max_seqlen_q": max_seqlen_q,
        "max_seqlen_k": max_seqlen_k,
        "min_seqlen_q": min_seqlen_q,
        "dropout_p": dropout_p,
        "softmax_scale": softmax_scale,
        "causal": causal,
        "window_size": window_size,
        "alibi_slopes": alibi_slopes,
        "return_lse": return_lse,
        "out": out,
    }
    return _aiter_varlen_attn(**attn_kwargs)
@staticmethod
def pa_fwd_asm(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    block_tables: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables_stride0: int,
    K_QScale: torch.Tensor,
    V_QScale: torch.Tensor,
    out_: torch.Tensor,
):
    """Paged-attention forward pass via the AITER assembly kernel.

    Deliberately NOT decorated with @if_aiter_supported so that an
    explicitly selected AITER backend (via attention_config) keeps
    working even when VLLM_ROCM_USE_AITER=0.

    Note: ``aiter.pa_fwd_asm`` is imported lazily, only when this
    wrapper is actually invoked.
    """
    # Lazy import avoids touching the aiter package at module load time.
    from aiter import pa_fwd_asm as _aiter_pa_fwd_asm

    pa_kwargs = {
        "Q": Q,
        "K": K,
        "V": V,
        "block_tables": block_tables,
        "context_lens": context_lens,
        "block_tables_stride0": block_tables_stride0,
        "K_QScale": K_QScale,
        "V_QScale": V_QScale,
        "out_": out_,
    }
    return _aiter_pa_fwd_asm(**pa_kwargs)
rocm_aiter_ops.register_ops_once()

View File

@@ -8,6 +8,12 @@ from vllm.platforms import current_platform
logger = init_logger(__name__)
# Track whether upstream flash-attn is available on ROCm.
# Set during module initialization and never modified afterwards.
# This module-level flag avoids repeated import attempts and ensures
# consistent behavior (similar to IS_AITER_FOUND in _aiter_ops.py).
_ROCM_FLASH_ATTN_AVAILABLE = False
if current_platform.is_cuda():
from vllm._custom_ops import reshape_and_cache_flash
from vllm.vllm_flash_attn import ( # type: ignore[attr-defined]
@@ -26,6 +32,9 @@ elif current_platform.is_xpu():
elif current_platform.is_rocm():
try:
from flash_attn import flash_attn_varlen_func # type: ignore[no-redef]
# Mark that upstream flash-attn is available on ROCm
_ROCM_FLASH_ATTN_AVAILABLE = True
except ImportError:
def flash_attn_varlen_func(*args: Any, **kwargs: Any) -> Any: # type: ignore[no-redef,misc]
@@ -34,6 +43,15 @@ elif current_platform.is_rocm():
"to be installed. Please install flash-attn first."
)
# ROCm doesn't use scheduler metadata (FA3 feature), provide stub
def get_scheduler_metadata(*args: Any, **kwargs: Any) -> None: # type: ignore[misc]
    # ROCm stub: scheduler metadata is an FA3 (CUDA) feature, so ROCm
    # callers always receive None here regardless of the arguments.
    return None
# ROCm uses the C++ custom op for reshape_and_cache
from vllm import _custom_ops as ops
reshape_and_cache_flash = ops.reshape_and_cache_flash
def get_flash_attn_version(requires_alibi: bool = False) -> int | None:
# import here to avoid circular dependencies
@@ -128,4 +146,30 @@ def flash_attn_supports_mla():
def is_flash_attn_varlen_func_available() -> bool:
return current_platform.is_cuda() or current_platform.is_xpu()
"""Check if flash_attn_varlen_func is available.
This function determines whether the flash_attn_varlen_func imported at module
level is a working implementation or a stub.
Platform-specific sources:
- CUDA: vllm.vllm_flash_attn.flash_attn_varlen_func
- XPU: ipex_ops.flash_attn_varlen_func
- ROCm: upstream flash_attn.flash_attn_varlen_func (if available)
Note: This is separate from the AITER flash attention backend (rocm_aiter_fa.py)
which uses rocm_aiter_ops.flash_attn_varlen_func. The condition to use AITER is
handled separately via _aiter_ops.is_aiter_found_and_supported().
Returns:
bool: True if a working flash_attn_varlen_func implementation is available.
"""
if current_platform.is_cuda() or current_platform.is_xpu():
# CUDA and XPU always have flash_attn_varlen_func available
return True
if current_platform.is_rocm():
# Use the flag set during module import to check if
# upstream flash-attn was successfully imported
return _ROCM_FLASH_ATTN_AVAILABLE
return False

View File

@@ -34,9 +34,6 @@ _CP_TOKENS_PER_ITER_ROCM = 32 * 1024
if current_platform.is_rocm():
from vllm.triton_utils import tl, triton
if rocm_aiter_ops.is_enabled():
import aiter
def block_size(x, head_dim):
return min(65536 // x.element_size(), triton.next_power_of_2(head_dim))
@@ -798,7 +795,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
total_tokens=swa_total_tokens,
)
aiter.flash_attn_varlen_func(
rocm_aiter_ops.flash_attn_varlen_func(
q=query,
k=key_fetched,
v=value_fetched,
@@ -848,7 +845,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
v_scale,
)
return
out, lse = aiter.flash_attn_varlen_func(
out, lse = rocm_aiter_ops.flash_attn_varlen_func(
q=query,
k=key,
v=value,
@@ -895,7 +892,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
total_tokens=total_token_per_batch[chunk_idx],
)
suf_out, suf_lse = aiter.flash_attn_varlen_func(
suf_out, suf_lse = rocm_aiter_ops.flash_attn_varlen_func(
q=query,
k=key_fetched,
v=value_fetched,
@@ -1053,7 +1050,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
prefill_key = key[num_decode_tokens + num_extend_tokens :]
prefill_value = value[num_decode_tokens + num_extend_tokens :]
aiter.flash_attn_varlen_func(
rocm_aiter_ops.flash_attn_varlen_func(
q=prefill_query,
k=prefill_key,
v=prefill_value,
@@ -1159,7 +1156,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
)
new_key_cache = key_cache.view_as(k_cache_template)
new_value_cache = value_cache.view_as(v_cache_template)
aiter.pa_fwd_asm(
rocm_aiter_ops.pa_fwd_asm(
Q=query[:num_decode_tokens],
K=new_key_cache,
V=new_value_cache,
@@ -1188,6 +1185,10 @@ class AiterFlashAttentionImpl(AttentionImpl):
device=output.device,
)
# import so that aiter register the op to the namespace of
# torch.ops.aiter
import aiter # noqa: F401
torch.ops.aiter.paged_attention_v1(
output[:num_decode_tokens],
workspace_buffer,

View File

@@ -222,9 +222,13 @@ class SpecDecodeBaseProposer:
RocmAttentionMetadata,
]
# ROCM_AITER_FA is an optional backend
from vllm._aiter_ops import rocm_aiter_ops
if rocm_aiter_ops.is_enabled() and find_spec(
# We check is_enabled() here to avoid importing the backend module during
# auto-discovery when VLLM_ROCM_USE_AITER=0, which would trigger aiter
# import and JIT compilation warnings. Explicit backend selection via
# attention_config still works because the backend module is loaded
# directly when selected, not through this auto-discovery path.
# Check if backend module exists to allow explicit selection
if find_spec(
AttentionBackendEnum.ROCM_AITER_FA.get_path(include_classname=False)
):
from vllm.v1.attention.backends.rocm_aiter_fa import (