[Bugfix] Refactor composite weight loading logic (#8656)

This commit is contained in:
Isotr0py
2024-09-22 12:33:27 +08:00
committed by GitHub
parent d66ac62854
commit 13d88d4137
7 changed files with 70 additions and 61 deletions

View File

@@ -1,7 +1,6 @@
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
"""PyTorch Ultravox model."""
-import itertools
import math
from array import array
from functools import lru_cache
@@ -29,7 +28,8 @@ from vllm.model_executor.layers.quantization.base_config import (
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.interfaces import SupportsMultiModal
-from vllm.model_executor.models.utils import (filter_weights, flatten_bn,
+from vllm.model_executor.models.utils import (flatten_bn,
+                                              group_weights_with_prefix,
init_vllm_registered_model,
merge_multimodal_embeddings)
from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -467,11 +467,10 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# prepare weight iterators for components
-        projector_weights, llm_weights = itertools.tee(weights, 2)
+        weights_group = group_weights_with_prefix(weights)
        # load projector weights
-        projector_weights = filter_weights(projector_weights,
-                                           "multi_modal_projector")
+        projector_weights = weights_group["multi_modal_projector"]
projector_params_dict = dict(
self.multi_modal_projector.named_parameters())
for name, loaded_weight in projector_weights:
@@ -481,5 +480,4 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
weight_loader(param, loaded_weight)
# load llm backbone
-        llm_weights = filter_weights(llm_weights, "language_model")
-        self.language_model.load_weights(llm_weights)
+        self.language_model.load_weights(weights_group["language_model"])