[Core] Refactor model loading code (#4097)

This commit is contained in:
Antoni Baum
2024-04-16 11:34:39 -07:00
committed by GitHub
parent 05434764cd
commit 69e1d2fb69
67 changed files with 1054 additions and 963 deletions

View File

@@ -1,4 +1,4 @@
from typing import List, Optional
from typing import Iterable, List, Optional, Tuple
import torch
from torch import nn
@@ -13,10 +13,9 @@ from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.model_executor.weight_utils import (default_weight_loader,
hf_model_weights_iterator)
from vllm.sequence import SamplerOutput
_KEYS_TO_MODIFY_MAPPING = {
@@ -198,11 +197,7 @@ class LlavaForConditionalGeneration(nn.Module):
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens
def load_weights(self,
model_name_or_path: str,
cache_dir: Optional[str] = None,
load_format: str = "auto",
revision: Optional[str] = None):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# only doing this for language model part for now.
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
@@ -213,8 +208,7 @@ class LlavaForConditionalGeneration(nn.Module):
("gate_up_proj", "up_proj", 1),
]
params_dict = dict(self.named_parameters())
for name, loaded_weight in hf_model_weights_iterator(
model_name_or_path, cache_dir, load_format, revision):
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():