[ROCm] Attention selector reordering (#36702)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com> Signed-off-by: Micah Williamson <micah.williamson@amd.com> Co-authored-by: Micah Williamson <micah.williamson@amd.com>
This commit is contained in:
committed by
GitHub
parent
09c3dc9186
commit
189ddefbfd
@@ -195,7 +195,10 @@ class RocmAttentionBackend(AttentionBackend):
|
||||
|
||||
@classmethod
|
||||
def supports_sink(cls) -> bool:
|
||||
return True
|
||||
# ROCM custom attention kernel does not support sinks.
|
||||
# Calling this backend with sinks will cause it to fall back to the Triton
|
||||
# kernel, which is less efficient than the proper triton backends.
|
||||
return False
|
||||
|
||||
forward_includes_kv_cache_update: bool = False
|
||||
|
||||
|
||||
@@ -10,11 +10,14 @@
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .prefix_prefill import context_attention_fwd
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
float8_info = torch.finfo(current_platform.fp8_dtype())
|
||||
|
||||
|
||||
@@ -392,6 +395,10 @@ def chunked_prefill_paged_decode(
|
||||
fp8_out_scale=output_scale,
|
||||
)
|
||||
else:
|
||||
logger.warning_once(
|
||||
"Cannot use ROCm custom paged attention kernel,"
|
||||
" falling back to Triton implementation."
|
||||
)
|
||||
real_block_size = value_cache.shape[3]
|
||||
# The standard model directly uses the original block_size.
|
||||
# Non-standard 544 uses 32 to accommodate integer division logic.
|
||||
|
||||
Reference in New Issue
Block a user