[ROCm] Refactor ROCm attention backend selection logic (#35246)
Signed-off-by: Sage Moore <sage@neuralmagic.com>
@@ -12,6 +12,7 @@ from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.logger import init_logger
 from vllm.model_executor.layers.attention import Attention
 from vllm.platforms import current_platform
+from vllm.platforms.interface import DeviceCapability
 from vllm.utils.math_utils import cdiv
 from vllm.utils.platform_utils import num_compute_units
 from vllm.v1.attention.backend import (
@@ -766,6 +767,15 @@ class AiterFlashAttentionBackend(AttentionBackend):
             raise ValueError("Block size must be a multiple of 16.")
         return (2, num_blocks, block_size, num_kv_heads, head_size)
+
+    @classmethod
+    def supports_compute_capability(cls, capability: DeviceCapability) -> bool:
+        from vllm.platforms.rocm import on_mi3xx
+
+        # DeviceCapability is currently created using torch.cuda.get_device_capability(),
+        # which is known to be buggy on ROCm systems. on_mi3xx uses amd-smi, which is
+        # more reliable.
+        return on_mi3xx()
 
 
 class AiterFlashAttentionImpl(AttentionImpl):
     def __init__(
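
For context on how this hook gets used, here is a minimal, hypothetical sketch of a selector that filters candidate backends through the new classmethod. pick_attention_backend and the candidates argument are illustrative assumptions, not vLLM's actual selection code; only DeviceCapability and the supports_compute_capability signature come from the diff above.

# Hypothetical sketch, not vLLM's actual selection logic: filter candidate
# attention backends through the classmethod added in this commit.
from vllm.platforms.interface import DeviceCapability


def pick_attention_backend(
    capability: DeviceCapability, candidates: list[type]
) -> type:
    # Return the first backend, in priority order, that reports support
    # for the current device.
    for backend in candidates:
        if backend.supports_compute_capability(capability):
            return backend
    raise RuntimeError("No attention backend supports this device.")

Putting the hardware probe inside each backend's classmethod keeps device quirks, such as the amd-smi workaround above, local to the backend that needs them instead of accumulating in a central if/elif chain in the platform selector.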