[MM Encoder]: Migrate legacy ViT MultiHeadAttention to new MMEncoderAttention interface (#30684)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Author: Isotr0py
Date: 2025-12-19 02:04:19 +08:00
Committed by: GitHub
Parent: 62be3670cb
Commit: 700a5ad6c6
20 changed files with 182 additions and 266 deletions
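
The change below is mechanical across the vision encoders: each module swaps the legacy `MultiHeadAttention` (from `vllm.attention.layer`) for `MMEncoderAttention` (from `vllm.attention.layers.mm_encoder_attention`), keeping the same constructor arguments (`num_heads`, `head_dim`, `scale`, plus `num_kv_heads=` where a model uses it) and the same `(q, k, v)` call site. A minimal sketch of the pattern, assuming the new class preserves the legacy constructor and call signatures as the hunks below suggest; `ExampleVisionAttention` is a hypothetical module used only for illustration:

import torch
import torch.nn as nn

from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
# Legacy interface removed by this commit:
# from vllm.attention.layer import MultiHeadAttention


class ExampleVisionAttention(nn.Module):
    """Hypothetical ViT attention block; the real modules also wire up
    QKV/output projections, tensor parallelism, and quantization."""

    def __init__(self, num_heads: int, head_dim: int) -> None:
        super().__init__()
        self.scale = head_dim**-0.5
        # Before: self.attn = MultiHeadAttention(num_heads, head_dim, self.scale)
        self.attn = MMEncoderAttention(num_heads, head_dim, self.scale)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        # The call site is unchanged: query/key/value in, attention output out.
        return self.attn(q, k, v)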

View File

@@ -8,7 +8,7 @@ from collections.abc import Iterable
 import torch
 import torch.nn as nn
-from vllm.attention.layer import MultiHeadAttention
+from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.utils import divide
 from vllm.model_executor.layers.activation import SiluAndMul
@@ -126,7 +126,7 @@ class AIMv2Attention(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
         self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
-        self.attn = MultiHeadAttention(
+        self.attn = MMEncoderAttention(
             self.num_heads_per_partition, self.head_dim, self.scale
         )

View File

@@ -9,7 +9,7 @@ import torch
 import torch.nn as nn
 from transformers import Blip2VisionConfig, BlipVisionConfig
-from vllm.attention.layer import MultiHeadAttention
+from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.conv import Conv2dLayer
@@ -122,7 +122,7 @@ class BlipAttention(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
         self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
-        self.attn = MultiHeadAttention(
+        self.attn = MMEncoderAttention(
             self.num_heads_per_partition, self.head_dim, self.scale
         )

View File

@@ -14,7 +14,8 @@ from transformers import (
     CLIPVisionConfig,
 )
-from vllm.attention.layer import Attention, MultiHeadAttention
+from vllm.attention.layer import Attention
+from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
@@ -354,7 +355,7 @@ class CLIPAttention(nn.Module):
         quant_config: QuantizationConfig | None = None,
         *,
         prefix: str = "",
-        attn_cls: type[Attention] | type[MultiHeadAttention],
+        attn_cls: type[Attention] | type[MMEncoderAttention],
     ) -> None:
         super().__init__()
@@ -449,7 +450,7 @@ class CLIPEncoderLayer(nn.Module):
         quant_config: QuantizationConfig | None = None,
         *,
         prefix: str = "",
-        attn_cls: type[Attention] | type[MultiHeadAttention],
+        attn_cls: type[Attention] | type[MMEncoderAttention],
     ) -> None:
         super().__init__()
         self.self_attn = CLIPAttention(
@@ -493,7 +494,7 @@ class CLIPEncoder(nn.Module):
         num_hidden_layers_override: int | None = None,
         *,
         prefix: str = "",
-        attn_cls: type[Attention] | type[MultiHeadAttention],
+        attn_cls: type[Attention] | type[MMEncoderAttention],
     ) -> None:
         super().__init__()
@@ -638,7 +639,7 @@ class CLIPVisionTransformer(nn.Module):
             quant_config=quant_config,
             num_hidden_layers_override=num_hidden_layers_override,
             prefix=f"{prefix}.encoder",
-            attn_cls=MultiHeadAttention,
+            attn_cls=MMEncoderAttention,
         )
         num_hidden_layers = config.num_hidden_layers
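
CLIP (and Siglip further down) does not hard-code the attention implementation: the encoder layers accept an `attn_cls` parameter and the vision transformer pins it to `MMEncoderAttention`, while other callers can presumably still pass the KV-cache-backed `Attention`. A simplified, hedged sketch of that injection pattern; `ExampleEncoderLayer` and its reduced argument list are illustrative only (the real classes also take config, quant_config, and prefix):

import torch.nn as nn

from vllm.attention.layer import Attention
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention


class ExampleEncoderLayer(nn.Module):
    """Hypothetical, stripped-down stand-in for CLIPEncoderLayer/SiglipEncoderLayer."""

    def __init__(
        self,
        num_heads: int,
        head_dim: int,
        *,
        attn_cls: type[Attention] | type[MMEncoderAttention],
    ) -> None:
        super().__init__()
        # The layer only assumes that attn_cls(num_heads, head_dim, scale)
        # builds a module callable as attn(q, k, v).
        self.attn = attn_cls(num_heads, head_dim, head_dim**-0.5)


# The vision transformer fixes the multimodal-encoder implementation:
layer = ExampleEncoderLayer(16, 64, attn_cls=MMEncoderAttention)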

View File

@@ -18,7 +18,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from transformers import CLIPVisionConfig
-from vllm.attention.layer import MultiHeadAttention
+from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
 from vllm.model_executor.layers.conv import Conv2dLayer
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -628,7 +628,7 @@ class DeepCLIPVisionTransformer(nn.Module):
             quant_config=quant_config,
             num_hidden_layers_override=num_hidden_layers_override,
             prefix=f"{prefix}.encoder",
-            attn_cls=MultiHeadAttention,
+            attn_cls=MMEncoderAttention,
         )
         num_hidden_layers = config.num_hidden_layers

View File

@@ -19,7 +19,7 @@ from transformers import BatchFeature, PreTrainedTokenizer, TensorType
 from transformers.image_utils import ImageInput
 from transformers.tokenization_utils_base import TextInput
-from vllm.attention.layer import MultiHeadAttention
+from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.distributed import get_tensor_model_parallel_world_size
@@ -135,7 +135,7 @@ class EVA2CLIPAttention(nn.Module):
             prefix=f"{prefix}.dense",
         )
-        self.attn = MultiHeadAttention(
+        self.attn = MMEncoderAttention(
             self.num_heads_per_rank, self.head_dim, self.scale
         )
         self.output_dropout = torch.nn.Dropout(config.dropout_prob)

View File

@@ -34,7 +34,7 @@ import torch.nn.functional as F
 from transformers import BatchFeature
 from vllm.attention.backends.registry import AttentionBackendEnum
-from vllm.attention.layer import MultiHeadAttention
+from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
 from vllm.config import MultiModalConfig, VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.distributed import parallel_state
@@ -232,7 +232,7 @@ class HunYuanVisionAttention(nn.Module):
         )
         self.scale = self.hidden_size_per_attention_head**-0.5
-        self.attn = MultiHeadAttention(
+        self.attn = MMEncoderAttention(
             self.num_attention_heads_per_partition,
             self.hidden_size_per_attention_head,
             self.scale,

View File

@@ -27,7 +27,7 @@ from transformers.models.idefics2.configuration_idefics2 import (
     Idefics2VisionConfig,
 )
-from vllm.attention.layer import MultiHeadAttention
+from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.conv import Conv2dLayer
@@ -161,8 +161,8 @@ class Idefics2VisionAttention(nn.Module):
             prefix=f"{prefix}.out_proj",
             disable_tp=use_data_parallel,
         )
-        # Use unified MultiHeadAttention with Flash Attention support
-        self.attn = MultiHeadAttention(
+        # Use unified MMEncoderAttention with Flash Attention support
+        self.attn = MMEncoderAttention(
             self.num_heads_per_partition, self.head_dim, self.scale
         )
@@ -175,7 +175,7 @@
         ) # batch_size, q_len, 3 * num_heads_per_partition * head_dim
         query_states, key_states, value_states = qkv.chunk(3, dim=-1)
-        # Use unified MultiHeadAttention implementation
+        # Use unified MMEncoderAttention implementation
         out = self.attn(query_states, key_states, value_states)
         attn_output, _ = self.out_proj(out)
         return attn_output

View File

@@ -15,7 +15,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from transformers import PretrainedConfig
-from vllm.attention.layer import MultiHeadAttention
+from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
 from vllm.distributed import (
     divide,
     get_tensor_model_parallel_rank,
@@ -207,7 +207,7 @@ class InternParallelAttention(nn.Module):
             disable_tp=use_data_parallel,
         )
-        self.attn = MultiHeadAttention(
+        self.attn = MMEncoderAttention(
             self.num_heads_per_partition, self.head_dim, self.scale
         )

View File

@@ -14,7 +14,7 @@ import torch.nn as nn
 from transformers import PretrainedConfig
 from transformers.utils import torch_int
-from vllm.attention.layer import MultiHeadAttention
+from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.conv import Conv2dLayer
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -214,8 +214,8 @@ class InternSdpaAttention(nn.Module):
         self.projection_layer = nn.Linear(self.dummy_dim, self.embed_dim)
-        # Use unified MultiHeadAttention with automatic backend selection
-        self.attn = MultiHeadAttention(self.num_heads, self.head_dim, self.scale)
+        # Use unified MMEncoderAttention with automatic backend selection
+        self.attn = MMEncoderAttention(self.num_heads, self.head_dim, self.scale)
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """x shape: (B, N, C)"""
@@ -228,7 +228,7 @@
         q = self.q_norm(q)
         k = self.k_norm(k)
-        # Use unified MultiHeadAttention with automatic backend selection
+        # Use unified MMEncoderAttention with automatic backend selection
         x = self.attn(q, k, v)
         x = self.projection_layer(x)

View File

@@ -31,7 +31,7 @@ from transformers.models.llama4.image_processing_llama4_fast import (
     get_best_fit,
 )
-from vllm.attention.layer import MultiHeadAttention
+from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.distributed import get_tensor_model_parallel_world_size
@@ -255,7 +255,7 @@ class Llama4VisionAttention(nn.Module):
         self.attention_dropout = config.attention_dropout
         self.scaling = self.head_dim**-0.5
-        self.attn = MultiHeadAttention(
+        self.attn = MMEncoderAttention(
             self.num_local_heads, self.head_dim, self.scaling
         )

View File

@@ -17,7 +17,8 @@ from transformers import BatchFeature, PretrainedConfig, ProcessorMixin, TensorType
 from transformers.image_utils import ImageInput
 from transformers.tokenization_utils_base import TextInput
-from vllm.attention.layer import Attention, MultiHeadAttention
+from vllm.attention.layer import Attention
+from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
@@ -222,7 +223,7 @@ class MultiHeadDotProductAttention(nn.Module):
         )
         self.scale = self.head_dim**-0.5
-        self.attn = MultiHeadAttention(
+        self.attn = MMEncoderAttention(
             self.num_heads, self.head_dim, self.scale, num_kv_heads=self.num_kv_heads
         )

View File

@@ -16,8 +16,8 @@ from transformers import (
     SiglipVisionConfig,
 )
-from vllm.attention.layer import MultiHeadAttention
 from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
+from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
@@ -379,7 +379,7 @@ class SiglipAttention(nn.Module):
         quant_config: QuantizationConfig | None = None,
         *,
         prefix: str = "",
-        attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention],
+        attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention],
     ) -> None:
         super().__init__()
@@ -481,7 +481,7 @@ class SiglipEncoderLayer(nn.Module):
         quant_config: QuantizationConfig | None = None,
         *,
         prefix: str = "",
-        attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention],
+        attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention],
     ) -> None:
         super().__init__()
@@ -527,7 +527,7 @@ class SiglipEncoder(nn.Module):
         num_hidden_layers_override: int | None = None,
         *,
         prefix: str = "",
-        attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention],
+        attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention],
     ) -> None:
         super().__init__()
@@ -700,7 +700,7 @@ class SiglipVisionTransformer(nn.Module):
             quant_config=quant_config,
             num_hidden_layers_override=num_hidden_layers_override,
             prefix=f"{prefix}.encoder",
-            attn_cls=MultiHeadAttention,
+            attn_cls=MMEncoderAttention,
         )
         num_hidden_layers = config.num_hidden_layers

View File

@@ -15,7 +15,7 @@ from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
 from transformers import BatchFeature, PretrainedConfig, TensorType
-from vllm.attention.layer import MultiHeadAttention
+from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.distributed import get_tensor_model_parallel_world_size
@@ -753,8 +753,8 @@ class Step3VisionAttention(nn.Module):
             disable_tp=use_data_parallel,
         )
-        # Use unified MultiHeadAttention with automatic backend selection
-        self.attn = MultiHeadAttention(self.num_heads, self.head_dim, self.scale)
+        # Use unified MMEncoderAttention with automatic backend selection
+        self.attn = MMEncoderAttention(self.num_heads, self.head_dim, self.scale)
     def forward(
         self,
@@ -767,7 +767,7 @@
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.chunk(chunks=3, dim=-1)
-        # Use unified MultiHeadAttention with automatic backend selection
+        # Use unified MMEncoderAttention with automatic backend selection
         attn_output = self.attn(q, k, v)
         attn_output, _ = self.out_proj(attn_output)

View File

@@ -16,9 +16,9 @@ from transformers import (
 )
 from transformers.models.whisper.modeling_whisper import sinusoids
-from vllm.attention.backends.abstract import AttentionType
-from vllm.attention.layer import Attention, MultiHeadAttention
+from vllm.attention.layer import Attention, AttentionType
 from vllm.attention.layers.cross_attention import CrossAttention
+from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
 from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.distributed import get_tensor_model_parallel_world_size
@@ -141,7 +141,7 @@ class WhisperAudioInputs(TensorSchema):
     ]
-class WhisperEncoderAttention(MultiHeadAttention):
+class WhisperEncoderAttention(MMEncoderAttention):
     """Multi-headed attention for Whisper encoder with 2D tensor support."""
     def forward(