[Multi Modal] Add FA3 in VIT (#24347)

Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
Author: Wenlong Wang
Date: 2025-09-12 06:27:24 -07:00
Committed by: GitHub
Parent: fdb09c77d6
Commit: 72fc8aa412

13 changed files with 247 additions and 66 deletions


@@ -7,7 +7,6 @@ from typing import Final, Generic, Optional, Protocol, TypeVar, Union
 
 import torch
 from transformers import PretrainedConfig
 
-from vllm.attention.selector import get_env_variable_attn_backend
 from vllm.logger import init_logger
 from vllm.platforms import _Backend, current_platform
@@ -68,17 +67,18 @@ def get_vision_encoder_info(
     raise NotImplementedError(msg)
 
 
-def get_vit_attn_backend(support_fa: bool = False) -> _Backend:
+def get_vit_attn_backend(head_size: int, dtype: torch.dtype) -> _Backend:
     """
     Get the available attention backend for Vision Transformer.
     """
-    # TODO(Isotr0py): Remove `support_fa` after support FA for all ViTs attn.
+    # Lazy import to avoid circular dependency
+    from vllm.attention.selector import get_env_variable_attn_backend
     selected_backend: Optional[_Backend] = get_env_variable_attn_backend()
     if selected_backend is not None:
         return selected_backend
 
-    return current_platform.get_vit_attn_backend(support_fa)
+    return current_platform.get_vit_attn_backend(head_size, dtype)
 
 
 def resolve_visual_encoder_outputs(
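For orientation: the old boolean `support_fa` asked each caller to declare whether it could use FlashAttention, while the new signature lets the platform decide from the actual head size and dtype. A minimal caller-side sketch, assuming a hypothetical ViT attention module (the class, field names, and import path are illustrative, not part of this diff):

import torch
from torch import nn

from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.platforms import _Backend


class ViTAttentionSketch(nn.Module):
    # Hypothetical caller showing the migrated call site.
    def __init__(self, hidden_size: int, num_heads: int, dtype: torch.dtype):
        super().__init__()
        self.head_dim = hidden_size // num_heads
        # Before: get_vit_attn_backend(support_fa=True)
        # After: the selector checks head_dim/dtype compatibility itself.
        self.attn_backend = get_vit_attn_backend(head_size=self.head_dim,
                                                 dtype=dtype)
        self.use_flash_attn = self.attn_backend == _Backend.FLASH_ATTN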
@@ -122,4 +122,4 @@ def resolve_visual_encoder_outputs(
         uses_last_layer = feature_sample_layers[-1] in (len(hs_pool) - 1, -1)
         if post_layer_norm is not None and uses_last_layer:
             hs_pool[-1] = post_layer_norm(encoder_outputs)
-    return torch.cat(hs_pool, dim=-1)
\ No newline at end of file
+    return torch.cat(hs_pool, dim=-1)
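With the new signature, each platform's `get_vit_attn_backend(head_size, dtype)` can gate FlashAttention (FA2/FA3) on kernel constraints instead of trusting a caller-supplied flag. A rough sketch of that contract, assuming typical FlashAttention limits (this is illustrative, not the actual vllm.platforms implementation):

import torch

from vllm.platforms import _Backend


def get_vit_attn_backend_sketch(head_size: int, dtype: torch.dtype) -> _Backend:
    # Illustrative constraints: FlashAttention kernels generally require
    # fp16/bf16 activations and a head size that is a multiple of 8 (<= 256).
    if dtype in (torch.float16, torch.bfloat16) \
            and head_size % 8 == 0 and head_size <= 256:
        return _Backend.FLASH_ATTN
    # Fall back to PyTorch SDPA when FlashAttention cannot run.
    return _Backend.TORCH_SDPA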