[Model] Refactor and decouple weight loading logic for InternVL2 model (#7067)

This commit is contained in:
Isotr0py
2024-08-03 13:36:14 +08:00
committed by GitHub
parent a0d164567c
commit 0c25435daa
2 changed files with 38 additions and 55 deletions

View File

@@ -4,7 +4,7 @@
# Copyright (c) 2023 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from typing import Optional
from typing import Iterable, Optional, Tuple
import torch
import torch.nn as nn
@@ -16,6 +16,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
NORM2FN = {
'rms_norm': RMSNorm,
@@ -268,3 +269,11 @@ class InternVisionModel(nn.Module):
encoder_outputs = self.encoder(inputs_embeds=hidden_states)
return encoder_outputs
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    """Copy checkpoint tensors into this module's parameters.

    Each (name, tensor) pair is matched against the module's named
    parameters by exact name. If a parameter carries a custom
    ``weight_loader`` attribute (e.g. set by a parallel linear layer),
    that loader is used; otherwise ``default_weight_loader`` performs a
    plain copy.
    """
    named_params = dict(self.named_parameters())
    for weight_name, tensor in weights:
        target = named_params[weight_name]
        # Sharded/parallel layers attach their own loader; fall back to
        # the default full-copy loader for everything else.
        loader = getattr(target, "weight_loader", default_weight_loader)
        loader(target, tensor)