[ROCm][AITER] Fix AITER import regression for explicit backend selection (#33749)

Signed-off-by: Andreas Karatzas <akaratza@amd.com>
Author: Andreas Karatzas
Date: 2026-02-06 09:08:16 -06:00 (committed by GitHub)
Parent: 1fb0495a72
Commit: 350ca72c04

5 changed files with 262 additions and 66 deletions


@@ -34,9 +34,6 @@ _CP_TOKENS_PER_ITER_ROCM = 32 * 1024
 if current_platform.is_rocm():
     from vllm.triton_utils import tl, triton
 
-    if rocm_aiter_ops.is_enabled():
-        import aiter
-
     def block_size(x, head_dim):
         return min(65536 // x.element_size(), triton.next_power_of_2(head_dim))
 
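The regression: this module-level `import aiter` was guarded by `rocm_aiter_ops.is_enabled()`, so explicitly selecting the AITER backend while the global flag was off left `aiter` unimported even though the `aiter.*` calls below still ran. The replacement `rocm_aiter_ops` wrappers defer the import to call time. A minimal sketch of that pattern follows; the wrapper body is an assumption for illustration, not vLLM's actual implementation:

    # Hypothetical sketch of a call-time-import wrapper in the style of
    # rocm_aiter_ops; the real vLLM module differs in detail.
    def flash_attn_varlen_func(*args, **kwargs):
        # Deferred import: aiter is only touched when the op actually runs,
        # so backend selection no longer depends on a module-level guard.
        import aiter
        return aiter.flash_attn_varlen_func(*args, **kwargs)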
@@ -798,7 +795,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
                 total_tokens=swa_total_tokens,
             )
 
-            aiter.flash_attn_varlen_func(
+            rocm_aiter_ops.flash_attn_varlen_func(
                 q=query,
                 k=key_fetched,
                 v=value_fetched,
@@ -848,7 +845,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
                 v_scale,
             )
             return
-        out, lse = aiter.flash_attn_varlen_func(
+        out, lse = rocm_aiter_ops.flash_attn_varlen_func(
             q=query,
             k=key,
             v=value,
@@ -895,7 +892,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
                 total_tokens=total_token_per_batch[chunk_idx],
             )
 
-            suf_out, suf_lse = aiter.flash_attn_varlen_func(
+            suf_out, suf_lse = rocm_aiter_ops.flash_attn_varlen_func(
                 q=query,
                 k=key_fetched,
                 v=value_fetched,
@@ -1053,7 +1050,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
         prefill_key = key[num_decode_tokens + num_extend_tokens :]
         prefill_value = value[num_decode_tokens + num_extend_tokens :]
 
-        aiter.flash_attn_varlen_func(
+        rocm_aiter_ops.flash_attn_varlen_func(
             q=prefill_query,
             k=prefill_key,
             v=prefill_value,
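For context, a hedged reproduction of the failure mode this commit targets. The backend name string is an assumption based on the commit title, not verified against vLLM's backend registry:

    import os

    # Explicitly pick the AITER flash-attention backend while the global
    # AITER switch is off (exact backend name is an assumption).
    os.environ["VLLM_ATTENTION_BACKEND"] = "ROCM_AITER_FA"
    os.environ["VLLM_ROCM_USE_AITER"] = "0"  # is_enabled() -> False

    # Before this commit: the guarded module-level `import aiter` was
    # skipped, so the direct aiter.flash_attn_varlen_func(...) calls above
    # failed at runtime. After: calls route through rocm_aiter_ops, which
    # imports aiter when first used.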
@@ -1159,7 +1156,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
         )
         new_key_cache = key_cache.view_as(k_cache_template)
         new_value_cache = value_cache.view_as(v_cache_template)
-        aiter.pa_fwd_asm(
+        rocm_aiter_ops.pa_fwd_asm(
             Q=query[:num_decode_tokens],
             K=new_key_cache,
             V=new_value_cache,
@@ -1188,6 +1185,10 @@ class AiterFlashAttentionImpl(AttentionImpl):
                 device=output.device,
             )
+            # Import so that aiter registers its ops in the
+            # torch.ops.aiter namespace.
+            import aiter  # noqa: F401
+
             torch.ops.aiter.paged_attention_v1(
                 output[:num_decode_tokens],
                 workspace_buffer,
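This last hunk relies on import-time op registration: `torch.ops.<namespace>.<op>` only resolves after the library that defines the op has been imported and run its registration code. An illustrative standalone example using `torch.library` follows; aiter's actual registration mechanism may differ, and the `demo::scale` op is invented for the sketch:

    import torch

    # Defining a custom op registers it under torch.ops at definition time,
    # which is why `import aiter` must run before the
    # torch.ops.aiter.paged_attention_v1 call above can resolve.
    @torch.library.custom_op("demo::scale", mutates_args=())
    def scale(x: torch.Tensor, s: float) -> torch.Tensor:
        return x * s

    out = torch.ops.demo.scale(torch.ones(4), 2.0)  # resolves after import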