[Model] Standardize common vision encoders (#31947)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
@@ -11,7 +11,6 @@ from torch.nn import functional as F
 from transformers import Siglip2VisionConfig
 from transformers.configuration_utils import PretrainedConfig

 from vllm.attention.backends.registry import AttentionBackendEnum
 from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
 from vllm.config import MultiModalConfig
 from vllm.distributed import divide, get_tensor_model_parallel_world_size
@@ -186,7 +185,6 @@ class Siglip2Attention(nn.Module):
         multimodal_config: MultiModalConfig | None = None,
         prefix: str = "",
-        use_data_parallel: bool = False,
         attn_backend_override: AttentionBackendEnum | None = None,
     ):
         super().__init__()
         self.config = config
@@ -196,12 +194,11 @@ class Siglip2Attention(nn.Module):
         if self.head_dim * self.num_heads != self.embed_dim:
             raise ValueError(
                 f"embed_dim must be divisible by num_heads "
-                f"(got `embed_dim`: {self.embed_dim} and `num_heads`:"
-                f" {self.num_heads})."
+                f"(got `embed_dim`: {self.embed_dim} and "
+                f"`num_heads`: {self.num_heads})."
             )
         self.scale = self.head_dim**-0.5
         self.dropout = config.attention_dropout
         self.is_causal = False

+        use_data_parallel = (
+            multimodal_config.mm_encoder_tp_mode == "data"
@@ -233,6 +230,7 @@ class Siglip2Attention(nn.Module):
         self.attn = MMEncoderAttention(
             num_heads=self.num_heads_per_partition,
             head_size=self.head_dim,
             scale=self.scale,
             prefix=f"{prefix}.attn",
+            multimodal_config=multimodal_config,
         )
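For downstream callers, the net effect of these hunks is that Siglip2Attention no longer takes an explicit use_data_parallel flag: the layer derives data-parallel mode from multimodal_config and forwards that config to MMEncoderAttention. A minimal Python sketch of the derived-flag pattern follows; the helper name and the "if ... else False" completion are assumptions, since the third hunk is truncated in this view:

# Hypothetical helper illustrating the pattern introduced above; the
# else-branch completion is an assumption based on the truncated hunk.
def derive_use_data_parallel(multimodal_config) -> bool:
    # Data-parallel encoder mode is now inferred from the multimodal
    # config rather than passed as a constructor argument.
    return (
        multimodal_config.mm_encoder_tp_mode == "data"
        if multimodal_config is not None
        else False
    )

# Call sites accordingly pass the config instead of a boolean, e.g.:
#   Siglip2Attention(config, multimodal_config=mm_config, prefix=...)
# rather than:
#   Siglip2Attention(config, use_data_parallel=True, prefix=...)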