[Multi Modal] Add FA3 in VIT (#24347)
Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
@@ -13,6 +13,7 @@ from torch.nn import functional as F
 from transformers import Siglip2VisionConfig
 from transformers.configuration_utils import PretrainedConfig
 
+from vllm.attention.layer import check_upstream_fa_availability
 from vllm.config import QuantizationConfig
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
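
The new import, check_upstream_fa_availability, is used below to probe whether the upstream flash-attn package can serve as a fallback for the chosen ViT backend. A minimal sketch of what such a probe could look like — the real helper lives in vllm.attention.layer; the stand-in name upstream_fa_available below is hypothetical and only checks importability and dtype support:

```python
# Hypothetical stand-in for vllm.attention.layer.check_upstream_fa_availability:
# report whether the upstream flash-attn package is importable and the dtype
# is one its kernels handle (half precision only).
import importlib.util

import torch


def upstream_fa_available(dtype: torch.dtype) -> bool:
    if dtype not in (torch.float16, torch.bfloat16):
        return False
    return importlib.util.find_spec("flash_attn") is not None
```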
@@ -236,7 +237,15 @@ class Siglip2Attention(nn.Module):
         self.use_rope = config.use_rope
 
         # Detect attention implementation.
-        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
+        self.attn_backend = get_vit_attn_backend(
+            head_size=self.head_dim, dtype=torch.get_default_dtype())
+        self.use_upstream_fa = False
+        if self.attn_backend != _Backend.FLASH_ATTN and \
+            check_upstream_fa_availability(
+                torch.get_default_dtype()):
+            self.attn_backend = _Backend.FLASH_ATTN
+            self.use_upstream_fa = True
+
         if self.attn_backend not in {
                 _Backend.FLASH_ATTN, _Backend.TORCH_SDPA,
                 _Backend.ROCM_AITER_FA
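
The selection logic above asks vLLM for a ViT attention backend sized to this head dimension and dtype; if the pick is not FLASH_ATTN but upstream flash-attn is usable, it overrides the pick and remembers to import from flash_attn rather than vllm.vllm_flash_attn. A self-contained sketch of that decision under stated assumptions — the _Backend enum here is a stand-in for vLLM's internal enum, and select_backend is an illustrative name:

```python
# Sketch of the fallback decision made in __init__ above.
from enum import Enum, auto


class _Backend(Enum):
    FLASH_ATTN = auto()
    TORCH_SDPA = auto()
    ROCM_AITER_FA = auto()


def select_backend(picked: _Backend,
                   upstream_fa_ok: bool) -> tuple[_Backend, bool]:
    """Return (attn_backend, use_upstream_fa) mirroring the diff's logic."""
    if picked != _Backend.FLASH_ATTN and upstream_fa_ok:
        return _Backend.FLASH_ATTN, True
    return picked, False


assert select_backend(_Backend.TORCH_SDPA, True) == (_Backend.FLASH_ATTN, True)
assert select_backend(_Backend.FLASH_ATTN, True) == (_Backend.FLASH_ATTN, False)
```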
@@ -280,7 +289,10 @@ class Siglip2Attention(nn.Module):
         if self.attn_backend == _Backend.ROCM_AITER_FA:
             from aiter import flash_attn_varlen_func
         else:
-            from flash_attn import flash_attn_varlen_func
+            if self.use_upstream_fa:
+                from flash_attn import flash_attn_varlen_func
+            else:
+                from vllm.vllm_flash_attn import flash_attn_varlen_func
         attn_output = flash_attn_varlen_func(
             queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen,
             max_seqlen).reshape(seq_length, -1)
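
For readers unfamiliar with the varlen interface used above: queries, keys, and values are packed along a single token axis as (total_tokens, num_heads, head_dim), and cu_seqlens holds cumulative sequence boundaries so attention never crosses image boundaries. A reference emulation with plain torch SDPA, not the fused FA3 kernel — varlen_attention_reference is an illustrative name:

```python
# Per-segment emulation of flash_attn_varlen_func's self-attention semantics.
import torch
import torch.nn.functional as F


def varlen_attention_reference(q, k, v, cu_seqlens):
    # q, k, v: (total_tokens, num_heads, head_dim)
    out = torch.empty_like(q)
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        # Reshape each packed segment to (1, num_heads, seq, head_dim).
        qs = q[start:end].transpose(0, 1).unsqueeze(0)
        ks = k[start:end].transpose(0, 1).unsqueeze(0)
        vs = v[start:end].transpose(0, 1).unsqueeze(0)
        o = F.scaled_dot_product_attention(qs, ks, vs)
        out[start:end] = o.squeeze(0).transpose(0, 1)
    return out


# Two images with 4 and 6 patches packed into one batch-free tensor:
cu = torch.tensor([0, 4, 10])
q = k = v = torch.randn(10, 8, 64)
print(varlen_attention_reference(q, k, v, cu).shape)  # torch.Size([10, 8, 64])
```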