[Multi Modal] Add FA3 in VIT (#24347)
Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
@@ -44,6 +44,7 @@ from transformers.models.glm4v.video_processing_glm4v import (
     Glm4vVideoProcessor)
 from transformers.video_utils import VideoMetadata
 
+from vllm.attention.layer import check_upstream_fa_availability
 from vllm.config import VllmConfig
 from vllm.distributed import (get_tensor_model_parallel_world_size,
                               parallel_state)
@@ -260,7 +261,15 @@ class Glm4vVisionAttention(nn.Module):
         )
 
         # Detect attention implementation.
-        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
+        self.attn_backend = get_vit_attn_backend(
+            head_size=self.hidden_size_per_attention_head,
+            dtype=torch.get_default_dtype())
+        self.use_upstream_fa = False
+        if self.attn_backend != _Backend.FLASH_ATTN and \
+                check_upstream_fa_availability(torch.get_default_dtype()):
+            self.attn_backend = _Backend.FLASH_ATTN
+            self.use_upstream_fa = True
 
         if self.attn_backend not in {
             _Backend.FLASH_ATTN,
             _Backend.TORCH_SDPA,
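For readers unfamiliar with the pattern, here is a minimal, self-contained sketch of the fallback logic this hunk introduces. Every name in it (`Backend`, `has_upstream_flash_attn`, `pick_backend`) is an illustrative stand-in, not vLLM's API; the real entry points are `get_vit_attn_backend` and `check_upstream_fa_availability`.

```python
# Minimal sketch of the detection-with-fallback pattern above.
# All names here are illustrative stand-ins, not vLLM's actual API.
from enum import Enum, auto


class Backend(Enum):
    FLASH_ATTN = auto()
    TORCH_SDPA = auto()


def has_upstream_flash_attn() -> bool:
    """Stand-in for check_upstream_fa_availability: is the pip-installed
    flash-attn package importable?"""
    try:
        import flash_attn  # noqa: F401
        return True
    except ImportError:
        return False


def pick_backend(default: Backend) -> tuple[Backend, bool]:
    """If the platform default is not FLASH_ATTN but upstream flash-attn
    is installed, upgrade the backend and remember which package the
    kernel must later be imported from (use_upstream_fa)."""
    backend, use_upstream_fa = default, False
    if backend != Backend.FLASH_ATTN and has_upstream_flash_attn():
        backend = Backend.FLASH_ATTN
        use_upstream_fa = True
    return backend, use_upstream_fa
```

The flag matters because the diff below imports `flash_attn_varlen_func` from a different module depending on which package won the detection.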
@@ -310,7 +319,10 @@ class Glm4vVisionAttention(nn.Module):
         if self.attn_backend == _Backend.FLASH_ATTN:
-            # from vllm_flash_attn.flash_attn_interface import (
-            #     flash_attn_varlen_func)
-            from flash_attn import flash_attn_varlen_func
+            if self.use_upstream_fa:
+                from flash_attn import flash_attn_varlen_func
+            else:
+                from vllm.vllm_flash_attn import flash_attn_varlen_func
 
             q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
 
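The `b s ... -> (b s) ...` rearrange in the context above exists because the varlen entry point consumes flattened token sequences delimited by cumulative lengths. A hedged usage sketch follows, using the upstream flash-attn keyword signature; it requires a CUDA device and fp16/bf16 tensors, and the shapes are made up for the example (they are not GLM-4V's):

```python
# Illustrative call shape for flash_attn_varlen_func (upstream flash-attn
# signature; needs a CUDA device and fp16/bf16 inputs). Shapes are made up.
import torch
from flash_attn import flash_attn_varlen_func

num_heads, head_dim = 16, 64
# Two sequences of 32 patches each, flattened to (total_tokens, heads, dim),
# which is what the "b s ... -> (b s) ..." rearrange above produces.
q = torch.randn(64, num_heads, head_dim, dtype=torch.bfloat16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)
# Cumulative sequence lengths delimit the sequences inside the flat batch.
cu_seqlens = torch.tensor([0, 32, 64], dtype=torch.int32, device="cuda")

out = flash_attn_varlen_func(q, k, v,
                             cu_seqlens_q=cu_seqlens,
                             cu_seqlens_k=cu_seqlens,
                             max_seqlen_q=32,
                             max_seqlen_k=32)
```

The `else` branch prefers vLLM's bundled `vllm.vllm_flash_attn` wheel, which is where the FA3 kernels this PR's title refers to live on supported GPUs; the upstream `flash_attn` package is used only when detection had to fall back to it.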
@@ -715,7 +727,11 @@ class Glm4vVisionTransformer(nn.Module):
         self.post_layernorm = RMSNorm(vision_config.hidden_size,
                                       eps=vision_config.rms_norm_eps)
 
-        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
+        self.attn_backend = get_vit_attn_backend(
+            head_size=head_dim, dtype=torch.get_default_dtype())
+        if self.attn_backend != _Backend.FLASH_ATTN and \
+                check_upstream_fa_availability(torch.get_default_dtype()):
+            self.attn_backend = _Backend.FLASH_ATTN
 
     @property
     def dtype(self) -> torch.dtype:
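This transformer-level hunk repeats the attention-layer detection but keeps only the backend choice, not `use_upstream_fa`: the outer module does not import the kernel itself, it only needs the backend for per-sequence bookkeeping. A hedged sketch of that kind of bookkeeping, modeled on similar ViT code paths in vLLM (the function name and return convention here are assumptions, not this diff's code):

```python
# Hedged sketch: per-backend sequence bookkeeping an outer ViT module
# might do. FLASH_ATTN consumes cu_seqlens plus a max sequence length,
# while other backends typically want per-sequence lengths instead.
# The function name and return convention are assumptions.
import torch


def attn_seqlen_inputs(cu_seqlens: torch.Tensor, backend: str):
    seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
    if backend == "FLASH_ATTN":
        return int(seq_lens.max().item()), None
    return None, seq_lens.tolist()


# Example: two sequences of lengths 32 and 16.
cu = torch.tensor([0, 32, 48], dtype=torch.int32)
print(attn_seqlen_inputs(cu, "FLASH_ATTN"))  # (32, None)
print(attn_seqlen_inputs(cu, "TORCH_SDPA"))  # (None, [32, 16])
```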