[Bugfix] Refactor composite weight loading logic (#8656)
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
# Adapted from https://github.com/fixie-ai/ultravox/blob/ecd58c4041030bae2ad15aa6bcf04ab43199ea02/ultravox/model/ultravox_model.py
|
||||
"""PyTorch Ultravox model."""
|
||||
|
||||
import itertools
|
||||
import math
|
||||
from array import array
|
||||
from functools import lru_cache
|
||||
@@ -29,7 +28,8 @@ from vllm.model_executor.layers.quantization.base_config import (
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.interfaces import SupportsMultiModal
|
||||
from vllm.model_executor.models.utils import (filter_weights, flatten_bn,
|
||||
from vllm.model_executor.models.utils import (flatten_bn,
|
||||
group_weights_with_prefix,
|
||||
init_vllm_registered_model,
|
||||
merge_multimodal_embeddings)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
@@ -467,11 +467,10 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
# prepare weight iterators for components
|
||||
projector_weights, llm_weights = itertools.tee(weights, 2)
|
||||
weights_group = group_weights_with_prefix(weights)
|
||||
|
||||
# load projector weights
|
||||
projector_weights = filter_weights(projector_weights,
|
||||
"multi_modal_projector")
|
||||
projector_weights = weights_group["multi_modal_projector"]
|
||||
projector_params_dict = dict(
|
||||
self.multi_modal_projector.named_parameters())
|
||||
for name, loaded_weight in projector_weights:
|
||||
@@ -481,5 +480,4 @@ class UltravoxModel(nn.Module, SupportsMultiModal):
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
# load llm backbone
|
||||
llm_weights = filter_weights(llm_weights, "language_model")
|
||||
self.language_model.load_weights(llm_weights)
|
||||
self.language_model.load_weights(weights_group["language_model"])
|
||||
|
||||
Reference in New Issue
Block a user