[Multi Modal] Add FA3 in VIT (#24347)

Signed-off-by: wwl2755 <wangwenlong2755@gmail.com>
This commit is contained in:
Wenlong Wang
2025-09-12 06:27:24 -07:00
committed by GitHub
parent fdb09c77d6
commit 72fc8aa412
13 changed files with 247 additions and 66 deletions

View File

@@ -44,6 +44,7 @@ from transformers.models.glm4v.video_processing_glm4v import (
Glm4vVideoProcessor)
from transformers.video_utils import VideoMetadata
from vllm.attention.layer import check_upstream_fa_availability
from vllm.config import VllmConfig
from vllm.distributed import (get_tensor_model_parallel_world_size,
parallel_state)
@@ -260,7 +261,15 @@ class Glm4vVisionAttention(nn.Module):
)
# Detect attention implementation.
self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
self.attn_backend = get_vit_attn_backend(
head_size=self.hidden_size_per_attention_head,
dtype=torch.get_default_dtype())
self.use_upstream_fa = False
if self.attn_backend != _Backend.FLASH_ATTN and \
check_upstream_fa_availability(torch.get_default_dtype()):
self.attn_backend = _Backend.FLASH_ATTN
self.use_upstream_fa = True
if self.attn_backend not in {
_Backend.FLASH_ATTN,
_Backend.TORCH_SDPA,
@@ -310,7 +319,10 @@ class Glm4vVisionAttention(nn.Module):
if self.attn_backend == _Backend.FLASH_ATTN:
# from vllm_flash_attn.flash_attn_interface import (
# flash_attn_varlen_func)
from flash_attn import flash_attn_varlen_func
if self.use_upstream_fa:
from flash_attn import flash_attn_varlen_func
else:
from vllm.vllm_flash_attn import flash_attn_varlen_func
q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
@@ -715,7 +727,11 @@ class Glm4vVisionTransformer(nn.Module):
self.post_layernorm = RMSNorm(vision_config.hidden_size,
eps=vision_config.rms_norm_eps)
self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim, dtype=torch.get_default_dtype())
if self.attn_backend != _Backend.FLASH_ATTN and \
check_upstream_fa_availability(torch.get_default_dtype()):
self.attn_backend = _Backend.FLASH_ATTN
@property
def dtype(self) -> torch.dtype: