[Model] Remove unnecessary weight initialization logic (#11736)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
Cyrus Leung
2025-01-04 23:46:21 +08:00
committed by GitHub
parent ba214dffbe
commit 65c08928c2
3 changed files with 5 additions and 22 deletions

View File

@@ -27,7 +27,7 @@
 Shared resampler perceiver network used in multimodal models and
 related helpers for sincos positional embeddings.
-Example models: Qwen (Qwen-VL), Minicpmv2.0
+Example models: Qwen (Qwen-VL), MiniCPM-V 2.0
 """
 import math
 from functools import partial
@@ -37,7 +37,6 @@ import numpy as np
 import torch
 import torch.nn.functional as F
 from torch import nn
-from torch.nn.init import trunc_normal_

 from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -169,8 +168,8 @@ class BaseResampler(nn.Module):
         self.embed_dim = embed_dim
         self.num_heads = num_heads

-        self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim))
-        trunc_normal_(self.query, std=0.02)
+        self.query = nn.Parameter(torch.empty(self.num_queries, embed_dim))
         if kv_dim is not None and kv_dim != embed_dim:
             self.kv_proj = ReplicatedLinear(kv_dim,
                                             embed_dim,
@@ -190,16 +189,7 @@ class BaseResampler(nn.Module):
         self.ln_post = norm_layer(embed_dim) if do_post_projection else None
         self.proj = nn.Parameter(
             (embed_dim**-0.5) *
-            torch.randn(embed_dim, embed_dim)) if do_post_projection else None
-
-    def _init_weights(self, m: nn.Module) -> None:
-        if isinstance(m, nn.Linear):
-            trunc_normal_(m.weight, std=0.02)
-            if isinstance(m, nn.Linear) and m.bias is not None:
-                nn.init.constant_(m.bias, 0)
-        elif isinstance(m, nn.LayerNorm):
-            nn.init.constant_(m.bias, 0)
-            nn.init.constant_(m.weight, 1.0)
+            torch.empty(embed_dim, embed_dim)) if do_post_projection else None

     def _repeat(self, query, N: int):
         return query.unsqueeze(1).repeat(1, N, 1)
@@ -240,8 +230,6 @@ class Resampler2(BaseResampler):
         self.pos_embed = nn.Parameter(
             torch.from_numpy(pos_embed_arr).requires_grad_(False))

-        self.apply(self._init_weights)
-
     def forward(
         self,
         x: torch.Tensor,

View File

@@ -3,7 +3,6 @@ from typing import (Callable, Iterable, List, Mapping, Optional, Set, Tuple,

 import torch
 import torch.nn as nn
-from torch.nn.init import trunc_normal_
 from transformers import BatchFeature, PretrainedConfig

 from vllm.attention import AttentionMetadata
@@ -216,9 +215,7 @@ class AriaProjector(nn.Module):
         self.num_heads = num_heads

         self.query = nn.Parameter(
-            torch.zeros(max(patch_to_query_dict.values()), self.embed_dim))
-
-        trunc_normal_(self.query, std=0.02)
+            torch.empty(max(patch_to_query_dict.values()), self.embed_dim))

         self.cross_attn = CrossAttention(kv_dim, embed_dim, num_heads)

View File

@@ -141,8 +141,6 @@ class Resampler2_5(BaseResampler):
         self.max_size = max_size
         self._set_2d_pos_cache(self.max_size)

-        self.apply(self._init_weights)
-
     def _set_2d_pos_cache(self,
                           max_size: Tuple[int, int],
                           device: torch.types.Device = "cpu") -> None: