[Attention] Refactor CUDA attention backend selection logic (#24794)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
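In short, this commit migrates the Qwen2.5-VL vision attention stack from the private _Backend enum to the public AttentionBackendEnum; the selection logic itself is unchanged. As a minimal sketch of the dispatch pattern the diff converges on (the select_vit_attn helper below is hypothetical and for illustration only; the enum members and error message are taken from the diff itself):

    from vllm.attention.backends.registry import AttentionBackendEnum

    # Hypothetical helper mirroring the dispatch in the diff below.
    def select_vit_attn(attn_backend: AttentionBackendEnum) -> str:
        # FLASH_ATTN and ROCM_AITER_FA share the fused flash-attention path.
        if attn_backend in {
            AttentionBackendEnum.FLASH_ATTN,
            AttentionBackendEnum.ROCM_AITER_FA,
        }:
            return "flash"
        # TORCH_SDPA executes attention entry by entry for speed & less VRAM.
        if attn_backend == AttentionBackendEnum.TORCH_SDPA:
            return "sdpa"
        if attn_backend == AttentionBackendEnum.XFORMERS:
            return "xformers"
        raise RuntimeError(f"Qwen2.5-VL does not support {attn_backend} backend now.")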
@@ -42,7 +42,7 @@ from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
     Qwen2_5_VLVisionConfig,
 )
 
-from vllm.attention.backends.registry import _Backend
+from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.attention.layer import maybe_get_vit_flash_attn_backend
 from vllm.attention.ops.vit_attn_wrappers import (
     vit_flash_attn_wrapper,
@@ -315,9 +315,9 @@ class Qwen2_5_VisionAttention(nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
-        attn_backend: _Backend = _Backend.TORCH_SDPA,
+        attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
         use_upstream_fa: bool = False,
-        attn_backend_override: _Backend | None = None,
+        attn_backend_override: AttentionBackendEnum | None = None,
     ) -> None:
         super().__init__()
         # Per attention head and per partition values.
@@ -364,13 +364,16 @@ class Qwen2_5_VisionAttention(nn.Module):
         # On ROCm with FLASH_ATTN backend, upstream flash_attn is used
         from vllm.platforms import current_platform
 
-        if current_platform.is_rocm() and self.attn_backend == _Backend.FLASH_ATTN:
+        if (
+            current_platform.is_rocm()
+            and self.attn_backend == AttentionBackendEnum.FLASH_ATTN
+        ):
             self.use_upstream_fa = True
         if current_platform.is_xpu():
             self.use_upstream_fa = False
         self.is_flash_attn_backend = self.attn_backend in {
-            _Backend.FLASH_ATTN,
-            _Backend.ROCM_AITER_FA,
+            AttentionBackendEnum.FLASH_ATTN,
+            AttentionBackendEnum.ROCM_AITER_FA,
         }
 
     def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
@@ -431,10 +434,10 @@ class Qwen2_5_VisionAttention(nn.Module):
                 cu_seqlens,
                 max_seqlen,
                 batch_size,
-                self.attn_backend == _Backend.ROCM_AITER_FA,
+                self.attn_backend == AttentionBackendEnum.ROCM_AITER_FA,
                 self.use_upstream_fa,
             )
-        elif self.attn_backend == _Backend.TORCH_SDPA:
+        elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
             # Execute attention entry by entry for speed & less VRAM.
             from vllm.platforms import current_platform
 
@@ -450,7 +453,7 @@ class Qwen2_5_VisionAttention(nn.Module):
                 v,
                 cu_seqlens,
             )
-        elif self.attn_backend == _Backend.XFORMERS:
+        elif self.attn_backend == AttentionBackendEnum.XFORMERS:
             context_layer = vit_xformers_attn_wrapper(q, k, v, seqlens)
 
         output, _ = self.proj(context_layer)
@@ -478,9 +481,9 @@ class Qwen2_5_VisionBlock(nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
-        attn_backend: _Backend = _Backend.TORCH_SDPA,
+        attn_backend: AttentionBackendEnum = AttentionBackendEnum.TORCH_SDPA,
         use_upstream_fa: bool = False,
-        attn_backend_override: _Backend | None = None,
+        attn_backend_override: AttentionBackendEnum | None = None,
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -656,7 +659,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
-        attn_backend_override: _Backend | None = None,
+        attn_backend_override: AttentionBackendEnum | None = None,
     ) -> None:
         super().__init__()
 
@@ -708,10 +711,10 @@ class Qwen2_5_VisionTransformer(nn.Module):
         )
 
         if self.attn_backend not in {
-            _Backend.FLASH_ATTN,
-            _Backend.TORCH_SDPA,
-            _Backend.XFORMERS,
-            _Backend.ROCM_AITER_FA,
+            AttentionBackendEnum.FLASH_ATTN,
+            AttentionBackendEnum.TORCH_SDPA,
+            AttentionBackendEnum.XFORMERS,
+            AttentionBackendEnum.ROCM_AITER_FA,
         }:
             raise RuntimeError(
                 f"Qwen2.5-VL does not support {self.attn_backend} backend now."
@@ -850,9 +853,12 @@ class Qwen2_5_VisionTransformer(nn.Module):
     ) -> tuple[torch.Tensor, torch.Tensor]:
         max_seqlen = torch.zeros([], device=cu_seqlens.device)
         seqlens = torch.zeros(1, device=cu_seqlens.device)
-        if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
+        if self.attn_backend in {
+            AttentionBackendEnum.FLASH_ATTN,
+            AttentionBackendEnum.ROCM_AITER_FA,
+        }:
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
-        elif self.attn_backend == _Backend.XFORMERS:
+        elif self.attn_backend == AttentionBackendEnum.XFORMERS:
             seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
         return max_seqlen, seqlens
 
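For reference, the seqlen bookkeeping in the last hunk can be sanity-checked standalone; a minimal sketch in plain PyTorch (the tensor values are illustrative, not from the source):

    import torch

    # cu_seqlens holds cumulative sequence lengths, e.g. two sequences of 3 and 5 tokens.
    cu_seqlens = torch.tensor([0, 3, 8])
    # Per-sequence lengths -- what the XFORMERS branch computes.
    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]               # tensor([3, 5])
    # Longest sequence -- what the FLASH_ATTN / ROCM_AITER_FA branch computes.
    max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()    # tensor(5)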