[BUGFIX][ROCM] ViT FlashAttention on ROCm (no GFX9) and contiguous on qwen3vl ROCm TORCH_SDPA (#27190)
Signed-off-by: JartX <sagformas@epdcenter.es> Co-authored-by: tjtanaa <tunjian.tan@embeddedllm.com>
This commit is contained in:
@@ -205,12 +205,16 @@ class RocmPlatform(Platform):
|
||||
|
||||
@classmethod
def get_vit_attn_backend(cls, head_size: int, dtype: torch.dtype) -> "_Backend":
    """Pick the attention backend for Vision Transformer (ViT) layers on ROCm.

    Preference order:
      1. ``ROCM_AITER_FA`` — when both AITER env toggles are enabled and the
         GPU is gfx9 (AITER MHA is gfx9-only, per the ``on_gfx9()`` guard).
      2. ``FLASH_ATTN`` — on gfx9 when the ``flash_attn`` package is
         importable (checked via ``find_spec`` without importing it).
      3. ``TORCH_SDPA`` — fallback for every other case, including non-gfx9
         ROCm GPUs and gfx9 systems without ``flash_attn`` installed.

    Args:
        head_size: Attention head dimension. Unused in the ROCm selection
            logic; kept for the shared ``Platform`` interface.
        dtype: Model dtype. Also unused here; part of the shared interface.

    Returns:
        The ``_Backend`` enum member to use for ViT attention.
    """
    # Local imports keep the registry/importlib cost off the module import
    # path and avoid a circular import with the attention registry.
    from importlib.util import find_spec

    from vllm.attention.backends.registry import _Backend

    # AITER flash-attention path: requires both env opt-ins AND gfx9 hardware.
    if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA and on_gfx9():
        return _Backend.ROCM_AITER_FA

    # flash_attn is only usable on gfx9 here; additionally require that the
    # package is actually installed before selecting it.
    if on_gfx9() and find_spec("flash_attn") is not None:
        return _Backend.FLASH_ATTN

    # Universal fallback (e.g. non-gfx9 ROCm GPUs, or flash_attn missing).
    return _Backend.TORCH_SDPA
|
||||
|
||||
@classmethod
|
||||
|
||||
Reference in New Issue
Block a user