[Multi Modal] Add FA3 in VIT (#24347)
Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
@@ -7,7 +7,6 @@ from typing import Final, Generic, Optional, Protocol, TypeVar, Union
 import torch
 from transformers import PretrainedConfig
 
-from vllm.attention.selector import get_env_variable_attn_backend
 from vllm.logger import init_logger
 from vllm.platforms import _Backend, current_platform
 
@@ -68,17 +67,18 @@ def get_vision_encoder_info(
     raise NotImplementedError(msg)
 
 
-def get_vit_attn_backend(support_fa: bool = False) -> _Backend:
+def get_vit_attn_backend(head_size: int, dtype: torch.dtype) -> _Backend:
     """
     Get the available attention backend for Vision Transformer.
     """
-    # TODO(Isotr0py): Remove `support_fa` after support FA for all ViTs attn.
+    # Lazy import to avoid circular dependency
+    from vllm.attention.selector import get_env_variable_attn_backend
 
     selected_backend: Optional[_Backend] = get_env_variable_attn_backend()
     if selected_backend is not None:
         return selected_backend
 
-    return current_platform.get_vit_attn_backend(support_fa)
+    return current_platform.get_vit_attn_backend(head_size, dtype)
 
 
 def resolve_visual_encoder_outputs(
@@ -122,4 +122,4 @@ def resolve_visual_encoder_outputs(
     uses_last_layer = feature_sample_layers[-1] in (len(hs_pool) - 1, -1)
     if post_layer_norm is not None and uses_last_layer:
         hs_pool[-1] = post_layer_norm(encoder_outputs)
-    return torch.cat(hs_pool, dim=-1)
+    return torch.cat(hs_pool, dim=-1)
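
For context, a minimal sketch of how a ViT attention module might consume the new get_vit_attn_backend(head_size, dtype) signature. The VisionAttention class, its tensor shapes, and the fallback path are illustrative assumptions, not code from this commit; only get_vit_attn_backend (assumed to live in vllm.model_executor.models.vision, as the diffed file suggests) and the _Backend enum come from vLLM itself.

# Illustrative sketch (not part of this commit): a hypothetical ViT
# attention module resolving its backend via the new signature.
import torch
import torch.nn.functional as F

from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.platforms import _Backend


class VisionAttention(torch.nn.Module):
    """Hypothetical caller; q, k, v are (batch, seq_len, num_heads, head_size)."""

    def __init__(self, embed_dim: int, num_heads: int, dtype: torch.dtype):
        super().__init__()
        self.head_size = embed_dim // num_heads
        # Backend selection now depends on head size and dtype, letting the
        # platform route FA-capable shapes to FlashAttention and fall back
        # to other backends otherwise.
        self.attn_backend = get_vit_attn_backend(self.head_size, dtype)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        if self.attn_backend == _Backend.FLASH_ATTN:
            # Requires the flash-attn package; expects (B, S, H, D) inputs.
            from flash_attn import flash_attn_func
            return flash_attn_func(q, k, v)
        # Fallback: torch SDPA expects (B, H, S, D), so transpose around it.
        out = F.scaled_dot_product_attention(
            q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))
        return out.transpose(1, 2)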