[FEAT][ROCm] Enable running Flash Attention as ViT attn backend for Qwen-VL models on ROCm platform. (#22069)
Signed-off-by: tjtanaavllm <tunjian.tan@amd.com>
Signed-off-by: vllmellm <vllm.ellm@embeddedllm.com>
Co-authored-by: tjtanaavllm <tunjian.tan@amd.com>
--- a/vllm/platforms/rocm.py
+++ b/vllm/platforms/rocm.py
@@ -173,6 +173,18 @@ class RocmPlatform(Platform):
         "quark", "ptpc_fp8"
     ]
 
+    @classmethod
+    def get_vit_attn_backend(cls, support_fa: bool = False) -> _Backend:
+        if support_fa:
+            if (envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_MHA
+                    and on_gfx9()):
+                # Note: AITER FA is only supported for Qwen-VL models.
+                # TODO: Add support for other VL models in their model class.
+                return _Backend.ROCM_AITER_FA
+            if on_gfx9():
+                return _Backend.FLASH_ATTN
+        return _Backend.TORCH_SDPA
+
     @classmethod
     def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
                              kv_cache_dtype, block_size, use_v1,
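For context, the following is a minimal, standalone sketch of the selection priority the new method implements. The `_Backend` member names, flag names, and decision order mirror the diff above; the simplified signature (plain booleans in place of vLLM's `envs` module and its `on_gfx9()` helper) is an assumption made purely for illustration, not vLLM's actual API.

```python
from enum import Enum, auto


class _Backend(Enum):
    ROCM_AITER_FA = auto()  # AITER flash attention (Qwen-VL only, gfx9)
    FLASH_ATTN = auto()     # Flash Attention on gfx9 hardware
    TORCH_SDPA = auto()     # portable fallback


def select_vit_attn_backend(support_fa: bool,
                            use_aiter: bool,
                            use_aiter_mha: bool,
                            is_gfx9: bool) -> _Backend:
    """Reproduce the priority order of get_vit_attn_backend (sketch)."""
    if support_fa:
        # AITER FA wins when both AITER flags are set on a gfx9 GPU.
        if use_aiter and use_aiter_mha and is_gfx9:
            return _Backend.ROCM_AITER_FA
        # Otherwise plain Flash Attention is used on gfx9.
        if is_gfx9:
            return _Backend.FLASH_ATTN
    # Non-gfx9 hardware, or the model does not support FA: fall back to SDPA.
    return _Backend.TORCH_SDPA


# Expected behavior under the three interesting configurations:
assert select_vit_attn_backend(True, True, True, True) is _Backend.ROCM_AITER_FA
assert select_vit_attn_backend(True, False, False, True) is _Backend.FLASH_ATTN
assert select_vit_attn_backend(False, True, True, True) is _Backend.TORCH_SDPA
```

In a real deployment the two AITER booleans correspond to the `VLLM_ROCM_USE_AITER` and `VLLM_ROCM_USE_AITER_MHA` environment variables: on a gfx9-class GPU, setting both selects the AITER flash-attention path for Qwen-VL models, while leaving them unset falls through to Flash Attention, and non-gfx9 hardware always gets `TORCH_SDPA`.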