[MM Encoder]: Migrate legacy ViT MultiHeadAttention to new MMEncoderAttention interface (#30684)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
@@ -8,7 +8,7 @@ from collections.abc import Iterable
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.attention.layer import MultiHeadAttention
|
||||
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.distributed.utils import divide
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
@@ -126,7 +126,7 @@ class AIMv2Attention(nn.Module):
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
|
||||
|
||||
self.attn = MultiHeadAttention(
|
||||
self.attn = MMEncoderAttention(
|
||||
self.num_heads_per_partition, self.head_dim, self.scale
|
||||
)
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
from transformers import Blip2VisionConfig, BlipVisionConfig
|
||||
|
||||
from vllm.attention.layer import MultiHeadAttention
|
||||
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
|
||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.conv import Conv2dLayer
|
||||
@@ -122,7 +122,7 @@ class BlipAttention(nn.Module):
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
|
||||
|
||||
self.attn = MultiHeadAttention(
|
||||
self.attn = MMEncoderAttention(
|
||||
self.num_heads_per_partition, self.head_dim, self.scale
|
||||
)
|
||||
|
||||
|
||||
@@ -14,7 +14,8 @@ from transformers import (
|
||||
CLIPVisionConfig,
|
||||
)
|
||||
|
||||
from vllm.attention.layer import Attention, MultiHeadAttention
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||
@@ -354,7 +355,7 @@ class CLIPAttention(nn.Module):
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
*,
|
||||
prefix: str = "",
|
||||
attn_cls: type[Attention] | type[MultiHeadAttention],
|
||||
attn_cls: type[Attention] | type[MMEncoderAttention],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -449,7 +450,7 @@ class CLIPEncoderLayer(nn.Module):
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
*,
|
||||
prefix: str = "",
|
||||
attn_cls: type[Attention] | type[MultiHeadAttention],
|
||||
attn_cls: type[Attention] | type[MMEncoderAttention],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.self_attn = CLIPAttention(
|
||||
@@ -493,7 +494,7 @@ class CLIPEncoder(nn.Module):
|
||||
num_hidden_layers_override: int | None = None,
|
||||
*,
|
||||
prefix: str = "",
|
||||
attn_cls: type[Attention] | type[MultiHeadAttention],
|
||||
attn_cls: type[Attention] | type[MMEncoderAttention],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -638,7 +639,7 @@ class CLIPVisionTransformer(nn.Module):
|
||||
quant_config=quant_config,
|
||||
num_hidden_layers_override=num_hidden_layers_override,
|
||||
prefix=f"{prefix}.encoder",
|
||||
attn_cls=MultiHeadAttention,
|
||||
attn_cls=MMEncoderAttention,
|
||||
)
|
||||
|
||||
num_hidden_layers = config.num_hidden_layers
|
||||
|
||||
@@ -18,7 +18,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import CLIPVisionConfig
|
||||
|
||||
from vllm.attention.layer import MultiHeadAttention
|
||||
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
|
||||
from vllm.model_executor.layers.conv import Conv2dLayer
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
@@ -628,7 +628,7 @@ class DeepCLIPVisionTransformer(nn.Module):
|
||||
quant_config=quant_config,
|
||||
num_hidden_layers_override=num_hidden_layers_override,
|
||||
prefix=f"{prefix}.encoder",
|
||||
attn_cls=MultiHeadAttention,
|
||||
attn_cls=MMEncoderAttention,
|
||||
)
|
||||
|
||||
num_hidden_layers = config.num_hidden_layers
|
||||
|
||||
@@ -19,7 +19,7 @@ from transformers import BatchFeature, PreTrainedTokenizer, TensorType
|
||||
from transformers.image_utils import ImageInput
|
||||
from transformers.tokenization_utils_base import TextInput
|
||||
|
||||
from vllm.attention.layer import MultiHeadAttention
|
||||
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
@@ -135,7 +135,7 @@ class EVA2CLIPAttention(nn.Module):
|
||||
prefix=f"{prefix}.dense",
|
||||
)
|
||||
|
||||
self.attn = MultiHeadAttention(
|
||||
self.attn = MMEncoderAttention(
|
||||
self.num_heads_per_rank, self.head_dim, self.scale
|
||||
)
|
||||
self.output_dropout = torch.nn.Dropout(config.dropout_prob)
|
||||
|
||||
@@ -34,7 +34,7 @@ import torch.nn.functional as F
|
||||
from transformers import BatchFeature
|
||||
|
||||
from vllm.attention.backends.registry import AttentionBackendEnum
|
||||
from vllm.attention.layer import MultiHeadAttention
|
||||
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
|
||||
from vllm.config import MultiModalConfig, VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.distributed import parallel_state
|
||||
@@ -232,7 +232,7 @@ class HunYuanVisionAttention(nn.Module):
|
||||
)
|
||||
|
||||
self.scale = self.hidden_size_per_attention_head**-0.5
|
||||
self.attn = MultiHeadAttention(
|
||||
self.attn = MMEncoderAttention(
|
||||
self.num_attention_heads_per_partition,
|
||||
self.hidden_size_per_attention_head,
|
||||
self.scale,
|
||||
|
||||
@@ -27,7 +27,7 @@ from transformers.models.idefics2.configuration_idefics2 import (
|
||||
Idefics2VisionConfig,
|
||||
)
|
||||
|
||||
from vllm.attention.layer import MultiHeadAttention
|
||||
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.conv import Conv2dLayer
|
||||
@@ -161,8 +161,8 @@ class Idefics2VisionAttention(nn.Module):
|
||||
prefix=f"{prefix}.out_proj",
|
||||
disable_tp=use_data_parallel,
|
||||
)
|
||||
# Use unified MultiHeadAttention with Flash Attention support
|
||||
self.attn = MultiHeadAttention(
|
||||
# Use unified MMEncoderAttention with Flash Attention support
|
||||
self.attn = MMEncoderAttention(
|
||||
self.num_heads_per_partition, self.head_dim, self.scale
|
||||
)
|
||||
|
||||
@@ -175,7 +175,7 @@ class Idefics2VisionAttention(nn.Module):
|
||||
) # batch_size, q_len, 3 * num_heads_per_partition * head_dim
|
||||
query_states, key_states, value_states = qkv.chunk(3, dim=-1)
|
||||
|
||||
# Use unified MultiHeadAttention implementation
|
||||
# Use unified MMEncoderAttention implementation
|
||||
out = self.attn(query_states, key_states, value_states)
|
||||
attn_output, _ = self.out_proj(out)
|
||||
return attn_output
|
||||
|
||||
@@ -15,7 +15,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.attention.layer import MultiHeadAttention
|
||||
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
|
||||
from vllm.distributed import (
|
||||
divide,
|
||||
get_tensor_model_parallel_rank,
|
||||
@@ -207,7 +207,7 @@ class InternParallelAttention(nn.Module):
|
||||
disable_tp=use_data_parallel,
|
||||
)
|
||||
|
||||
self.attn = MultiHeadAttention(
|
||||
self.attn = MMEncoderAttention(
|
||||
self.num_heads_per_partition, self.head_dim, self.scale
|
||||
)
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
from transformers.utils import torch_int
|
||||
|
||||
from vllm.attention.layer import MultiHeadAttention
|
||||
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.conv import Conv2dLayer
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
@@ -214,8 +214,8 @@ class InternSdpaAttention(nn.Module):
|
||||
|
||||
self.projection_layer = nn.Linear(self.dummy_dim, self.embed_dim)
|
||||
|
||||
# Use unified MultiHeadAttention with automatic backend selection
|
||||
self.attn = MultiHeadAttention(self.num_heads, self.head_dim, self.scale)
|
||||
# Use unified MMEncoderAttention with automatic backend selection
|
||||
self.attn = MMEncoderAttention(self.num_heads, self.head_dim, self.scale)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""x shape: (B, N, C)"""
|
||||
@@ -228,7 +228,7 @@ class InternSdpaAttention(nn.Module):
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
|
||||
# Use unified MultiHeadAttention with automatic backend selection
|
||||
# Use unified MMEncoderAttention with automatic backend selection
|
||||
x = self.attn(q, k, v)
|
||||
|
||||
x = self.projection_layer(x)
|
||||
|
||||
@@ -31,7 +31,7 @@ from transformers.models.llama4.image_processing_llama4_fast import (
|
||||
get_best_fit,
|
||||
)
|
||||
|
||||
from vllm.attention.layer import MultiHeadAttention
|
||||
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
@@ -255,7 +255,7 @@ class Llama4VisionAttention(nn.Module):
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.scaling = self.head_dim**-0.5
|
||||
|
||||
self.attn = MultiHeadAttention(
|
||||
self.attn = MMEncoderAttention(
|
||||
self.num_local_heads, self.head_dim, self.scaling
|
||||
)
|
||||
|
||||
|
||||
@@ -17,7 +17,8 @@ from transformers import BatchFeature, PretrainedConfig, ProcessorMixin, TensorT
|
||||
from transformers.image_utils import ImageInput
|
||||
from transformers.tokenization_utils_base import TextInput
|
||||
|
||||
from vllm.attention.layer import Attention, MultiHeadAttention
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import CacheConfig, VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
@@ -222,7 +223,7 @@ class MultiHeadDotProductAttention(nn.Module):
|
||||
)
|
||||
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.attn = MultiHeadAttention(
|
||||
self.attn = MMEncoderAttention(
|
||||
self.num_heads, self.head_dim, self.scale, num_kv_heads=self.num_kv_heads
|
||||
)
|
||||
|
||||
|
||||
@@ -16,8 +16,8 @@ from transformers import (
|
||||
SiglipVisionConfig,
|
||||
)
|
||||
|
||||
from vllm.attention.layer import MultiHeadAttention
|
||||
from vllm.attention.layers.encoder_only_attention import EncoderOnlyAttention
|
||||
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.distributed import divide, get_tensor_model_parallel_world_size
|
||||
@@ -379,7 +379,7 @@ class SiglipAttention(nn.Module):
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
*,
|
||||
prefix: str = "",
|
||||
attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention],
|
||||
attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -481,7 +481,7 @@ class SiglipEncoderLayer(nn.Module):
|
||||
quant_config: QuantizationConfig | None = None,
|
||||
*,
|
||||
prefix: str = "",
|
||||
attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention],
|
||||
attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -527,7 +527,7 @@ class SiglipEncoder(nn.Module):
|
||||
num_hidden_layers_override: int | None = None,
|
||||
*,
|
||||
prefix: str = "",
|
||||
attn_cls: type[EncoderOnlyAttention] | type[MultiHeadAttention],
|
||||
attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -700,7 +700,7 @@ class SiglipVisionTransformer(nn.Module):
|
||||
quant_config=quant_config,
|
||||
num_hidden_layers_override=num_hidden_layers_override,
|
||||
prefix=f"{prefix}.encoder",
|
||||
attn_cls=MultiHeadAttention,
|
||||
attn_cls=MMEncoderAttention,
|
||||
)
|
||||
|
||||
num_hidden_layers = config.num_hidden_layers
|
||||
|
||||
@@ -15,7 +15,7 @@ from torchvision import transforms
|
||||
from torchvision.transforms.functional import InterpolationMode
|
||||
from transformers import BatchFeature, PretrainedConfig, TensorType
|
||||
|
||||
from vllm.attention.layer import MultiHeadAttention
|
||||
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
@@ -753,8 +753,8 @@ class Step3VisionAttention(nn.Module):
|
||||
disable_tp=use_data_parallel,
|
||||
)
|
||||
|
||||
# Use unified MultiHeadAttention with automatic backend selection
|
||||
self.attn = MultiHeadAttention(self.num_heads, self.head_dim, self.scale)
|
||||
# Use unified MMEncoderAttention with automatic backend selection
|
||||
self.attn = MMEncoderAttention(self.num_heads, self.head_dim, self.scale)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -767,7 +767,7 @@ class Step3VisionAttention(nn.Module):
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
|
||||
# Use unified MultiHeadAttention with automatic backend selection
|
||||
# Use unified MMEncoderAttention with automatic backend selection
|
||||
attn_output = self.attn(q, k, v)
|
||||
|
||||
attn_output, _ = self.out_proj(attn_output)
|
||||
|
||||
@@ -16,9 +16,9 @@ from transformers import (
|
||||
)
|
||||
from transformers.models.whisper.modeling_whisper import sinusoids
|
||||
|
||||
from vllm.attention.backends.abstract import AttentionType
|
||||
from vllm.attention.layer import Attention, MultiHeadAttention
|
||||
from vllm.attention.layer import Attention, AttentionType
|
||||
from vllm.attention.layers.cross_attention import CrossAttention
|
||||
from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
|
||||
from vllm.config import CacheConfig, ModelConfig, SpeechToTextConfig, VllmConfig
|
||||
from vllm.config.multimodal import BaseDummyOptions
|
||||
from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
@@ -141,7 +141,7 @@ class WhisperAudioInputs(TensorSchema):
|
||||
]
|
||||
|
||||
|
||||
class WhisperEncoderAttention(MultiHeadAttention):
|
||||
class WhisperEncoderAttention(MMEncoderAttention):
|
||||
"""Multi-headed attention for Whisper encoder with 2D tensor support."""
|
||||
|
||||
def forward(
|
||||
|
||||
Reference in New Issue
Block a user