[ROCm] Attention selector reordering (#36702)
Signed-off-by: Gregory Shtrasberg <Gregory.Shtrasberg@amd.com> Signed-off-by: Micah Williamson <micah.williamson@amd.com> Co-authored-by: Micah Williamson <micah.williamson@amd.com>
This commit is contained in:
committed by
GitHub
parent
09c3dc9186
commit
189ddefbfd
@@ -282,7 +282,7 @@ apply_rocm_test_overrides() {
|
||||
|
||||
# --- LoRA: disable custom paged attention ---
|
||||
if [[ $cmds == *"pytest -v -s lora"* ]]; then
|
||||
cmds=${cmds//"pytest -v -s lora"/"VLLM_ROCM_CUSTOM_PAGED_ATTN=0 pytest -v -s lora"}
|
||||
cmds=${cmds//"pytest -v -s lora"/"pytest -v -s lora"}
|
||||
fi
|
||||
|
||||
# --- Kernel ignores ---
|
||||
|
||||
@@ -175,7 +175,7 @@ Priority is **1 = highest** (tried first).
|
||||
| `FLEX_ATTENTION` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16` | Any | Any | ❌ | ✅ | ❌ | Decoder, Encoder Only | Any |
|
||||
| `ROCM_AITER_FA` | | fp16, bf16 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | 16, 32 | 64, 128, 256 | ❌ | ❌ | ❌ | Decoder, Enc-Dec | N/A |
|
||||
| `ROCM_AITER_UNIFIED_ATTN` | | fp16, bf16 | `auto` | %16 | Any | ✅ | ✅ | ❌ | All | N/A |
|
||||
| `ROCM_ATTN` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | 32, 64, 80, 96, 128, 160, 192, 224, 256 | ✅ | ✅ | ❌ | All | N/A |
|
||||
| `ROCM_ATTN` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | 32, 64, 80, 96, 128, 160, 192, 224, 256 | ❌ | ✅ | ❌ | All | N/A |
|
||||
| `TREE_ATTN` | | fp16, bf16 | `auto`, `float16`, `bfloat16` | %16 | 32, 64, 96, 128, 160, 192, 224, 256 | ❌ | ❌ | ❌ | Decoder | Any |
|
||||
| `TRITON_ATTN` | | fp16, bf16, fp32 | `auto`, `float16`, `bfloat16`, `fp8`, `fp8_e4m3`, `fp8_e5m2` | %16 | Any | ✅ | ✅ | ❌ | All | Any |
|
||||
|
||||
|
||||
@@ -54,7 +54,7 @@ def mock_on_mi3xx():
|
||||
(
|
||||
{},
|
||||
None,
|
||||
AttentionBackendEnum.TRITON_ATTN.get_path(),
|
||||
AttentionBackendEnum.ROCM_ATTN.get_path(),
|
||||
),
|
||||
# Test Case 2: Explicit TRITON_ATTN backend
|
||||
(
|
||||
@@ -81,41 +81,24 @@ def mock_on_mi3xx():
|
||||
AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path(),
|
||||
),
|
||||
# Test Case 6: VLLM_ROCM_USE_AITER=1
|
||||
# (defaults to AITER FA when MHA not explicitly disabled)
|
||||
(
|
||||
{"VLLM_ROCM_USE_AITER": "1"},
|
||||
None,
|
||||
AttentionBackendEnum.ROCM_AITER_FA.get_path(),
|
||||
AttentionBackendEnum.ROCM_ATTN.get_path(),
|
||||
),
|
||||
# Test Case 7: VLLM_ROCM_USE_AITER=1 + VLLM_ROCM_USE_AITER_MHA=1
|
||||
(
|
||||
{"VLLM_ROCM_USE_AITER": "1", "VLLM_ROCM_USE_AITER_MHA": "1"},
|
||||
None,
|
||||
AttentionBackendEnum.ROCM_AITER_FA.get_path(),
|
||||
),
|
||||
# Test Case 8: VLLM_ROCM_USE_AITER=1 + VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION=1
|
||||
(
|
||||
{
|
||||
"VLLM_ROCM_USE_AITER": "1",
|
||||
"VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION": "1",
|
||||
},
|
||||
None,
|
||||
AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN.get_path(),
|
||||
),
|
||||
# Test Case 9: VLLM_ROCM_USE_AITER=1 + explicit TRITON_ATTN
|
||||
# Test Case 7: VLLM_ROCM_USE_AITER=1 + explicit TRITON_ATTN
|
||||
(
|
||||
{"VLLM_ROCM_USE_AITER": "1"},
|
||||
"TRITON_ATTN",
|
||||
AttentionBackendEnum.TRITON_ATTN.get_path(),
|
||||
),
|
||||
# Test Case 10: VLLM_ROCM_USE_AITER=1 + VLLM_ROCM_USE_AITER_MHA=0
|
||||
# (explicitly disabled)
|
||||
# Test Case 8: VLLM_ROCM_USE_AITER=1 + VLLM_ROCM_USE_AITER_MHA=0
|
||||
(
|
||||
{"VLLM_ROCM_USE_AITER": "1", "VLLM_ROCM_USE_AITER_MHA": "0"},
|
||||
None,
|
||||
AttentionBackendEnum.TRITON_ATTN.get_path(),
|
||||
AttentionBackendEnum.ROCM_ATTN.get_path(),
|
||||
),
|
||||
# Test Case 11: VLLM_ROCM_USE_AITER=1 + explicit ROCM_ATTN
|
||||
# Test Case 9: VLLM_ROCM_USE_AITER=1 + explicit ROCM_ATTN
|
||||
(
|
||||
{"VLLM_ROCM_USE_AITER": "1"},
|
||||
"ROCM_ATTN",
|
||||
|
||||
@@ -50,9 +50,9 @@ def is_aiter_found_and_supported() -> bool:
|
||||
VLLM_ROCM_USE_AITER=0, while preventing unwanted JIT warnings for auto-discovery.
|
||||
"""
|
||||
if current_platform.is_rocm() and IS_AITER_FOUND:
|
||||
from vllm.platforms.rocm import on_gfx9
|
||||
from vllm.platforms.rocm import on_mi3xx
|
||||
|
||||
return on_gfx9()
|
||||
return on_mi3xx()
|
||||
return False
|
||||
|
||||
|
||||
|
||||
@@ -117,7 +117,6 @@ if TYPE_CHECKING:
|
||||
VLLM_ROCM_USE_SKINNY_GEMM: bool = True
|
||||
VLLM_ROCM_FP8_PADDING: bool = True
|
||||
VLLM_ROCM_MOE_PADDING: bool = True
|
||||
VLLM_ROCM_CUSTOM_PAGED_ATTN: bool = True
|
||||
VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT: bool = False
|
||||
VLLM_ENABLE_V1_MULTIPROCESSING: bool = True
|
||||
VLLM_LOG_BATCHSIZE_INTERVAL: float = -1
|
||||
@@ -1001,10 +1000,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
"VLLM_ROCM_FP8_PADDING": lambda: bool(int(os.getenv("VLLM_ROCM_FP8_PADDING", "1"))),
|
||||
# Pad the weights for the moe kernel
|
||||
"VLLM_ROCM_MOE_PADDING": lambda: bool(int(os.getenv("VLLM_ROCM_MOE_PADDING", "1"))),
|
||||
# custom paged attention kernel for MI3* cards
|
||||
"VLLM_ROCM_CUSTOM_PAGED_ATTN": lambda: (
|
||||
os.getenv("VLLM_ROCM_CUSTOM_PAGED_ATTN", "True").lower() in ("true", "1")
|
||||
),
|
||||
# Whether to use the shuffled kv cache layout
|
||||
"VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT": lambda: (
|
||||
os.getenv("VLLM_ROCM_SHUFFLE_KV_CACHE_LAYOUT", "False").lower() in ("true", "1")
|
||||
|
||||
@@ -265,7 +265,6 @@ def use_rocm_custom_paged_attention(
|
||||
and (block_size == 16 or block_size == 32)
|
||||
and (gqa_ratio >= 1 and gqa_ratio <= 16)
|
||||
and max_seq_len <= 128 * 1024
|
||||
and (envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
|
||||
and sinks is None
|
||||
)
|
||||
|
||||
@@ -280,7 +279,6 @@ def use_rocm_custom_paged_attention(
|
||||
and max_seq_len <= 128 * 1024
|
||||
and alibi_slopes is None
|
||||
and kv_cache_dtype == "auto"
|
||||
and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN
|
||||
and sinks is None
|
||||
)
|
||||
|
||||
@@ -311,7 +309,7 @@ def _get_backend_priorities(
|
||||
use_mla: bool,
|
||||
use_sparse: bool,
|
||||
) -> list[AttentionBackendEnum]:
|
||||
from vllm._aiter_ops import rocm_aiter_ops
|
||||
from vllm._aiter_ops import is_aiter_found_and_supported, rocm_aiter_ops
|
||||
|
||||
if use_sparse:
|
||||
return [AttentionBackendEnum.ROCM_AITER_MLA_SPARSE]
|
||||
@@ -328,28 +326,15 @@ def _get_backend_priorities(
|
||||
AttentionBackendEnum.TRITON_MLA,
|
||||
]
|
||||
|
||||
backends = []
|
||||
|
||||
# Priority 1: Check for AITER Unified Attention (must check before MHA)
|
||||
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION:
|
||||
backends.append(AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN)
|
||||
|
||||
# Priority 2: Check for AITER MHA (Flash Attention)
|
||||
if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA:
|
||||
backends = [
|
||||
AttentionBackendEnum.ROCM_ATTN,
|
||||
]
|
||||
if rocm_aiter_ops.is_mha_enabled():
|
||||
backends.append(AttentionBackendEnum.ROCM_AITER_FA)
|
||||
|
||||
# Priority 3: Check for ROCM_ATTN (prefill-decode split)
|
||||
from vllm.config import get_current_vllm_config_or_none
|
||||
|
||||
vllm_config = get_current_vllm_config_or_none()
|
||||
if (
|
||||
vllm_config is not None
|
||||
and vllm_config.attention_config.use_prefill_decode_attention
|
||||
):
|
||||
backends.append(AttentionBackendEnum.ROCM_ATTN)
|
||||
|
||||
# Default: Triton Unified Attention
|
||||
if is_aiter_found_and_supported():
|
||||
backends.append(AttentionBackendEnum.ROCM_AITER_UNIFIED_ATTN)
|
||||
backends.append(AttentionBackendEnum.TRITON_ATTN)
|
||||
|
||||
return backends
|
||||
|
||||
|
||||
|
||||
@@ -195,7 +195,10 @@ class RocmAttentionBackend(AttentionBackend):
|
||||
|
||||
@classmethod
|
||||
def supports_sink(cls) -> bool:
|
||||
return True
|
||||
# ROCM custom attention kernel does not support sinks.
|
||||
# Callink this backend with sinks will cause it to fall back to the Triton
|
||||
# kernel, which is less efficient than the proper triton backends.
|
||||
return False
|
||||
|
||||
forward_includes_kv_cache_update: bool = False
|
||||
|
||||
|
||||
@@ -10,11 +10,14 @@
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import tl, triton
|
||||
|
||||
from .prefix_prefill import context_attention_fwd
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
float8_info = torch.finfo(current_platform.fp8_dtype())
|
||||
|
||||
|
||||
@@ -392,6 +395,10 @@ def chunked_prefill_paged_decode(
|
||||
fp8_out_scale=output_scale,
|
||||
)
|
||||
else:
|
||||
logger.warning_once(
|
||||
"Cannot use ROCm custom paged attention kernel,"
|
||||
" falling back to Triton implementation."
|
||||
)
|
||||
real_block_size = value_cache.shape[3]
|
||||
# The standard model directly uses the original block_size.
|
||||
# Non-standard 544 uses 32 to accommodate integer division logic.
|
||||
|
||||
Reference in New Issue
Block a user