[Model] Remove unnecessary weight initialization logic (#11736)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Signed-off-by: Isotr0py <2037008807@qq.com> Co-authored-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
@@ -27,7 +27,7 @@
|
|||||||
Shared resampler perceiver network used in multimodal models and
|
Shared resampler perceiver network used in multimodal models and
|
||||||
related helpers for sincos positional embeddings.
|
related helpers for sincos positional embeddings.
|
||||||
|
|
||||||
Example models: Qwen (Qwen-VL), Minicpmv2.0
|
Example models: Qwen (Qwen-VL), MiniCPM-V 2.0
|
||||||
"""
|
"""
|
||||||
import math
|
import math
|
||||||
from functools import partial
|
from functools import partial
|
||||||
@@ -37,7 +37,6 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.nn.init import trunc_normal_
|
|
||||||
|
|
||||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
@@ -169,8 +168,8 @@ class BaseResampler(nn.Module):
|
|||||||
self.embed_dim = embed_dim
|
self.embed_dim = embed_dim
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
|
|
||||||
self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
|
self.query = nn.Parameter(torch.empty(self.num_queries, embed_dim))
|
||||||
trunc_normal_(self.query, std=0.02)
|
|
||||||
if kv_dim is not None and kv_dim != embed_dim:
|
if kv_dim is not None and kv_dim != embed_dim:
|
||||||
self.kv_proj = ReplicatedLinear(kv_dim,
|
self.kv_proj = ReplicatedLinear(kv_dim,
|
||||||
embed_dim,
|
embed_dim,
|
||||||
@@ -190,16 +189,7 @@ class BaseResampler(nn.Module):
|
|||||||
self.ln_post = norm_layer(embed_dim) if do_post_projection else None
|
self.ln_post = norm_layer(embed_dim) if do_post_projection else None
|
||||||
self.proj = nn.Parameter(
|
self.proj = nn.Parameter(
|
||||||
(embed_dim**-0.5) *
|
(embed_dim**-0.5) *
|
||||||
torch.randn(embed_dim, embed_dim)) if do_post_projection else None
|
torch.empty(embed_dim, embed_dim)) if do_post_projection else None
|
||||||
|
|
||||||
def _init_weights(self, m: nn.Module) -> None:
|
|
||||||
if isinstance(m, nn.Linear):
|
|
||||||
trunc_normal_(m.weight, std=0.02)
|
|
||||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
|
||||||
nn.init.constant_(m.bias, 0)
|
|
||||||
elif isinstance(m, nn.LayerNorm):
|
|
||||||
nn.init.constant_(m.bias, 0)
|
|
||||||
nn.init.constant_(m.weight, 1.0)
|
|
||||||
|
|
||||||
def _repeat(self, query, N: int):
|
def _repeat(self, query, N: int):
|
||||||
return query.unsqueeze(1).repeat(1, N, 1)
|
return query.unsqueeze(1).repeat(1, N, 1)
|
||||||
@@ -240,8 +230,6 @@ class Resampler2(BaseResampler):
|
|||||||
self.pos_embed = nn.Parameter(
|
self.pos_embed = nn.Parameter(
|
||||||
torch.from_numpy(pos_embed_arr).requires_grad_(False))
|
torch.from_numpy(pos_embed_arr).requires_grad_(False))
|
||||||
|
|
||||||
self.apply(self._init_weights)
|
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor,
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ from typing import (Callable, Iterable, List, Mapping, Optional, Set, Tuple,
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.nn.init import trunc_normal_
|
|
||||||
from transformers import BatchFeature, PretrainedConfig
|
from transformers import BatchFeature, PretrainedConfig
|
||||||
|
|
||||||
from vllm.attention import AttentionMetadata
|
from vllm.attention import AttentionMetadata
|
||||||
@@ -216,9 +215,7 @@ class AriaProjector(nn.Module):
|
|||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
|
|
||||||
self.query = nn.Parameter(
|
self.query = nn.Parameter(
|
||||||
torch.zeros(max(patch_to_query_dict.values()), self.embed_dim))
|
torch.empty(max(patch_to_query_dict.values()), self.embed_dim))
|
||||||
|
|
||||||
trunc_normal_(self.query, std=0.02)
|
|
||||||
|
|
||||||
self.cross_attn = CrossAttention(kv_dim, embed_dim, num_heads)
|
self.cross_attn = CrossAttention(kv_dim, embed_dim, num_heads)
|
||||||
|
|
||||||
|
|||||||
@@ -141,8 +141,6 @@ class Resampler2_5(BaseResampler):
|
|||||||
self.max_size = max_size
|
self.max_size = max_size
|
||||||
self._set_2d_pos_cache(self.max_size)
|
self._set_2d_pos_cache(self.max_size)
|
||||||
|
|
||||||
self.apply(self._init_weights)
|
|
||||||
|
|
||||||
def _set_2d_pos_cache(self,
|
def _set_2d_pos_cache(self,
|
||||||
max_size: Tuple[int, int],
|
max_size: Tuple[int, int],
|
||||||
device: torch.types.Device = "cpu") -> None:
|
device: torch.types.Device = "cpu") -> None:
|
||||||
|
|||||||
Reference in New Issue
Block a user