[Multi Modal] Add FA3 in VIT (#24347)

Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
Author: Wenlong Wang
Date: 2025-09-12 06:27:24 -07:00
Committed by: GitHub
Parent: fdb09c77d6
Commit: 72fc8aa412
13 changed files with 247 additions and 66 deletions


@@ -13,6 +13,7 @@ from torch.nn import functional as F
 from transformers import Siglip2VisionConfig
 from transformers.configuration_utils import PretrainedConfig
+from vllm.attention.layer import check_upstream_fa_availability
 from vllm.config import QuantizationConfig
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
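
The only change in this hunk is the new check_upstream_fa_availability import; it is used below to decide whether the upstream flash-attn package can serve as a FlashAttention fallback for the ViT. The real helper lives in vllm.attention.layer; what follows is only a minimal sketch of what such a probe could look like, assuming it just confirms the package is importable and the dtype is half precision (names and logic are illustrative, not vLLM's implementation):

    # Illustrative sketch of an "upstream FA availability" probe.
    import importlib.util

    import torch


    def upstream_fa_available(dtype: torch.dtype) -> bool:
        # FlashAttention kernels only accept fp16/bf16 inputs.
        if dtype not in (torch.float16, torch.bfloat16):
            return False
        # Usable only when the upstream flash-attn wheel is installed.
        return importlib.util.find_spec("flash_attn") is not None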
@@ -236,7 +237,15 @@ class Siglip2Attention(nn.Module):
         self.use_rope = config.use_rope
         # Detect attention implementation.
-        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
+        self.attn_backend = get_vit_attn_backend(
+            head_size=self.head_dim, dtype=torch.get_default_dtype())
+        self.use_upstream_fa = False
+        if self.attn_backend != _Backend.FLASH_ATTN and \
+            check_upstream_fa_availability(
+                torch.get_default_dtype()):
+            self.attn_backend = _Backend.FLASH_ATTN
+            self.use_upstream_fa = True
         if self.attn_backend not in {
                 _Backend.FLASH_ATTN, _Backend.TORCH_SDPA,
                 _Backend.ROCM_AITER_FA
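
In effect, whatever backend get_vit_attn_backend picks for this head size and dtype is upgraded to FLASH_ATTN whenever the upstream flash-attn package is usable, and use_upstream_fa records which implementation to import in the forward pass. A hedged, self-contained restatement of that decision (the enum and helper here are stand-ins, not vLLM's API):

    from enum import Enum, auto


    class Backend(Enum):  # stand-in for vLLM's _Backend
        FLASH_ATTN = auto()
        TORCH_SDPA = auto()
        ROCM_AITER_FA = auto()


    def select_vit_backend(default: Backend,
                           upstream_fa_usable: bool) -> tuple[Backend, bool]:
        """Mirror the fallback above: prefer FlashAttention, remember its origin."""
        if default != Backend.FLASH_ATTN and upstream_fa_usable:
            # The platform default (e.g. TORCH_SDPA) is not vLLM's own FA
            # backend, but the upstream flash-attn package is installed,
            # so route the ViT through FlashAttention anyway.
            return Backend.FLASH_ATTN, True
        return default, False

    # e.g. select_vit_backend(Backend.TORCH_SDPA, True) -> (FLASH_ATTN, True)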
@@ -280,7 +289,10 @@ class Siglip2Attention(nn.Module):
             if self.attn_backend == _Backend.ROCM_AITER_FA:
                 from aiter import flash_attn_varlen_func
             else:
-                from flash_attn import flash_attn_varlen_func
+                if self.use_upstream_fa:
+                    from flash_attn import flash_attn_varlen_func
+                else:
+                    from vllm.vllm_flash_attn import flash_attn_varlen_func
             attn_output = flash_attn_varlen_func(
                 queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen,
                 max_seqlen).reshape(seq_length, -1)
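
With use_upstream_fa set in __init__, the call site imports the matching flash_attn_varlen_func: either vLLM's bundled vllm.vllm_flash_attn (which is where FA3 support for the ViT comes from) or the upstream flash_attn package. Both expose the same varlen interface, so the call itself is unchanged. A hedged usage sketch of that interface, assuming the upstream package and with illustrative sizes: all image patches are packed into one (total_tokens, num_heads, head_dim) tensor, and the same cu_seqlens/max_seqlen are passed for queries and keys because ViT attention is bidirectional within each image.

    # Illustrative varlen call; shapes and sizes are made up for the example.
    import torch
    from flash_attn import flash_attn_varlen_func

    num_heads, head_dim = 16, 72
    seq_lens = [1024, 576]              # patches per image in the batch
    total = sum(seq_lens)

    q = torch.randn(total, num_heads, head_dim,
                    dtype=torch.bfloat16, device="cuda")
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    # Prefix-sum boundaries of each image's patches: [0, 1024, 1600].
    cu_seqlens = torch.tensor([0, 1024, 1600], dtype=torch.int32, device="cuda")
    max_seqlen = max(seq_lens)

    # Same cu_seqlens/max_seqlen for Q and K: full attention inside each
    # image, no attention across images in the packed batch.
    out = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens,
                                 max_seqlen, max_seqlen)
    out = out.reshape(total, -1)        # (total, num_heads * head_dim)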