[Attention] Refactor CUDA attention backend selection logic (#24794)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
Signed-off-by: Matthew Bonanni <mbonanni001@gmail.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
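The change applied throughout the hunks below is mechanical: the private _Backend enum imported from vllm.attention.backends.registry is replaced by AttentionBackendEnum, and every comparison and set-membership check is updated to the new name. A minimal sketch of the new-style check follows; the wrapper function is illustrative only and not part of this diff:

from vllm.attention.backends.registry import AttentionBackendEnum

def uses_flash_attn_varlen(backend: AttentionBackendEnum) -> bool:
    # Same membership test Qwen2VisionAttention performs below when it sets
    # self.is_flash_attn_backend; the standalone helper is hypothetical.
    return backend in {
        AttentionBackendEnum.FLASH_ATTN,
        AttentionBackendEnum.ROCM_AITER_FA,
    }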
@@ -43,7 +43,7 @@ from transformers.models.qwen2_vl.configuration_qwen2_vl import (
 from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
 from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoProcessor

-from vllm.attention.backends.registry import _Backend
+from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.attention.layer import (
     check_upstream_fa_availability,
     maybe_get_vit_flash_attn_backend,
@@ -329,7 +329,7 @@ class Qwen2VisionAttention(nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
-        attn_backend_override: _Backend | None = None,
+        attn_backend_override: AttentionBackendEnum | None = None,
     ) -> None:
         super().__init__()
         # Per attention head and per partition values.
@@ -378,18 +378,18 @@ class Qwen2VisionAttention(nn.Module):
         )

         if self.attn_backend not in {
-            _Backend.FLASH_ATTN,
-            _Backend.TORCH_SDPA,
-            _Backend.XFORMERS,
-            _Backend.ROCM_AITER_FA,
+            AttentionBackendEnum.FLASH_ATTN,
+            AttentionBackendEnum.TORCH_SDPA,
+            AttentionBackendEnum.XFORMERS,
+            AttentionBackendEnum.ROCM_AITER_FA,
         }:
             raise RuntimeError(
                 f"Qwen2-VL does not support {self.attn_backend} backend now."
             )

         self.is_flash_attn_backend = self.attn_backend in {
-            _Backend.FLASH_ATTN,
-            _Backend.ROCM_AITER_FA,
+            AttentionBackendEnum.FLASH_ATTN,
+            AttentionBackendEnum.ROCM_AITER_FA,
         }

     def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
@@ -460,7 +460,7 @@ class Qwen2VisionAttention(nn.Module):
             context_layer = rearrange(
                 output, "(b s) h d -> s b (h d)", b=batch_size
             ).contiguous()
-        elif self.attn_backend == _Backend.TORCH_SDPA:
+        elif self.attn_backend == AttentionBackendEnum.TORCH_SDPA:
             # Execute attention entry by entry for speed & less VRAM.
             from vllm.platforms import current_platform

@@ -485,7 +485,7 @@ class Qwen2VisionAttention(nn.Module):
             context_layer = rearrange(
                 context_layer, "b s h d -> s b (h d)"
             ).contiguous()
-        elif self.attn_backend == _Backend.XFORMERS:
+        elif self.attn_backend == AttentionBackendEnum.XFORMERS:
             from xformers import ops as xops
             from xformers.ops.fmha.attn_bias import BlockDiagonalMask

@@ -515,7 +515,7 @@ class Qwen2VisionBlock(nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
-        attn_backend_override: _Backend | None = None,
+        attn_backend_override: AttentionBackendEnum | None = None,
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -679,7 +679,7 @@ class Qwen2VisionTransformer(nn.Module):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
         use_data_parallel: bool = False,
-        attn_backend_override: _Backend | None = None,
+        attn_backend_override: AttentionBackendEnum | None = None,
     ) -> None:
         super().__init__()

@@ -739,10 +739,11 @@ class Qwen2VisionTransformer(nn.Module):
             dtype=torch.get_default_dtype(),
             attn_backend_override=attn_backend_override,
         )
-        if self.attn_backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
-            torch.get_default_dtype()
+        if (
+            self.attn_backend != AttentionBackendEnum.FLASH_ATTN
+            and check_upstream_fa_availability(torch.get_default_dtype())
         ):
-            self.attn_backend = _Backend.FLASH_ATTN
+            self.attn_backend = AttentionBackendEnum.FLASH_ATTN

     @property
     def dtype(self) -> torch.dtype:
@@ -789,9 +790,12 @@
         self, cu_seqlens: torch.Tensor
     ) -> tuple[int | None, list[int] | None]:
         max_seqlen, seqlens = None, None
-        if self.attn_backend in {_Backend.FLASH_ATTN, _Backend.ROCM_AITER_FA}:
+        if self.attn_backend in {
+            AttentionBackendEnum.FLASH_ATTN,
+            AttentionBackendEnum.ROCM_AITER_FA,
+        }:
             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
-        elif self.attn_backend == _Backend.XFORMERS:
+        elif self.attn_backend == AttentionBackendEnum.XFORMERS:
             seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
         return max_seqlen, seqlens

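The last hunk computes the sequence-length metadata each backend consumes, assuming cu_seqlens holds cumulative sequence lengths for packed (varlen) sequences. A standalone sketch of that computation, kept outside the diff for clarity; the example tensor values are made up:

import torch

# cu_seqlens for three packed sequences of lengths 3, 5, and 2.
cu_seqlens = torch.tensor([0, 3, 8, 10])
seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]   # per-sequence lengths: tensor([3, 5, 2])
max_seqlen = seq_lens.max().item()            # used by FLASH_ATTN / ROCM_AITER_FA -> 5
seqlens = seq_lens.tolist()                   # used by XFORMERS -> [3, 5, 2]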