[ROCm][AITER] Fix AITER import regression for explicit backend selection (#33749)
Signed-off-by: Andreas Karatzas <akaratza@amd.com>
@@ -34,9 +34,6 @@ _CP_TOKENS_PER_ITER_ROCM = 32 * 1024
 if current_platform.is_rocm():
     from vllm.triton_utils import tl, triton
 
-    if rocm_aiter_ops.is_enabled():
-        import aiter
-
     def block_size(x, head_dim):
         return min(65536 // x.element_size(), triton.next_power_of_2(head_dim))
 
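The first hunk is the fix itself: the conditional module-level `import aiter` goes away, and the hunks below reroute every direct `aiter.*` call through the `rocm_aiter_ops` wrapper, so selecting the backend explicitly no longer triggers (or depends on) an eager import at module load. A minimal sketch of the lazy-forwarding pattern this relies on; `LazyAiterOps` is a hypothetical stand-in, not vLLM's actual `rocm_aiter_ops` implementation:

    from typing import Any


    class LazyAiterOps:
        """Forwards calls to aiter, importing it only on first use."""

        _mod: Any = None

        def _module(self) -> Any:
            if self._mod is None:
                import aiter  # deferred: runs only when an op is called

                self._mod = aiter
            return self._mod

        def flash_attn_varlen_func(self, *args: Any, **kwargs: Any) -> Any:
            return self._module().flash_attn_varlen_func(*args, **kwargs)

        def pa_fwd_asm(self, *args: Any, **kwargs: Any) -> Any:
            return self._module().pa_fwd_asm(*args, **kwargs)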
@@ -798,7 +795,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
             total_tokens=swa_total_tokens,
         )
 
-        aiter.flash_attn_varlen_func(
+        rocm_aiter_ops.flash_attn_varlen_func(
             q=query,
             k=key_fetched,
             v=value_fetched,
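This and the next three `flash_attn_varlen_func` hunks, plus the `pa_fwd_asm` hunk, are the same one-line substitution at each call site: only the callee changes, and every keyword argument stays as it was, so the wrapper has to forward them verbatim. A tiny self-contained check of that forwarding property, using fake objects in place of the real aiter module:

    # Stand-in for aiter: records which keyword arguments arrive.
    class _FakeAiter:
        def flash_attn_varlen_func(self, **kwargs):
            return sorted(kwargs)


    # Stand-in for a rocm_aiter_ops-style wrapper that forwards verbatim.
    class _Wrapper:
        def __init__(self, mod):
            self._mod = mod

        def flash_attn_varlen_func(self, **kwargs):
            return self._mod.flash_attn_varlen_func(**kwargs)


    assert _Wrapper(_FakeAiter()).flash_attn_varlen_func(q=1, k=2, v=3) == ["k", "q", "v"]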
@@ -848,7 +845,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
                 v_scale,
             )
             return
-        out, lse = aiter.flash_attn_varlen_func(
+        out, lse = rocm_aiter_ops.flash_attn_varlen_func(
            q=query,
            k=key,
            v=value,
@@ -895,7 +892,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
                 total_tokens=total_token_per_batch[chunk_idx],
             )
 
-            suf_out, suf_lse = aiter.flash_attn_varlen_func(
+            suf_out, suf_lse = rocm_aiter_ops.flash_attn_varlen_func(
                 q=query,
                 k=key_fetched,
                 v=value_fetched,
@@ -1053,7 +1050,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
         prefill_key = key[num_decode_tokens + num_extend_tokens :]
         prefill_value = value[num_decode_tokens + num_extend_tokens :]
 
-        aiter.flash_attn_varlen_func(
+        rocm_aiter_ops.flash_attn_varlen_func(
             q=prefill_query,
             k=prefill_key,
             v=prefill_value,
@@ -1159,7 +1156,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
         )
         new_key_cache = key_cache.view_as(k_cache_template)
         new_value_cache = value_cache.view_as(v_cache_template)
-        aiter.pa_fwd_asm(
+        rocm_aiter_ops.pa_fwd_asm(
             Q=query[:num_decode_tokens],
             K=new_key_cache,
             V=new_value_cache,
@@ -1188,6 +1185,10 @@ class AiterFlashAttentionImpl(AttentionImpl):
             device=output.device,
         )
 
+        # Import so that aiter registers its ops in the
+        # torch.ops.aiter namespace.
+        import aiter  # noqa: F401
+
         torch.ops.aiter.paged_attention_v1(
             output[:num_decode_tokens],
             workspace_buffer,
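The last hunk re-adds `import aiter` locally, immediately before the one use that still needs the module: `torch.ops.aiter.paged_attention_v1` is a custom op that aiter registers with PyTorch as a side effect of being imported, so the `torch.ops.aiter` namespace only resolves once that import has run (hence the `# noqa: F401` on an apparently unused import). A minimal, self-contained demonstration of the mechanism with a made-up `demo` namespace, assuming PyTorch >= 2.2 for the `torch.library` decorator API; this is not aiter's registration code:

    import torch

    # Before these registration calls run, torch.ops.demo.double does
    # not exist; in a library they execute at import time of the module
    # that contains them.
    torch.library.define("demo::double", "(Tensor x) -> Tensor")


    @torch.library.impl("demo::double", "cpu")
    def _double_cpu(x: torch.Tensor) -> torch.Tensor:
        return x * 2


    # Only now does the namespace resolve:
    print(torch.ops.demo.double(torch.ones(2)))  # tensor([2., 2.])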