[Bugfix] Refactor composite weight loading logic (#8656)

Isotr0py
2024-09-22 12:33:27 +08:00
committed by GitHub
parent d66ac62854
commit 13d88d4137
7 changed files with 70 additions and 61 deletions

@@ -1,3 +1,5 @@
+import itertools
+from collections import UserDict
 from typing import (Dict, Iterable, List, Literal, Optional, Protocol, Tuple,
                     Union, overload)
@@ -16,7 +18,23 @@ from vllm.sequence import IntermediateTensors
 from vllm.utils import is_pin_memory_available
 
 
-def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str):
+class WeightsGroup(UserDict):
+    """
+    Wraps a grouped weights dictionary for a more informative error message
+    when attempting to access a weight component that does not exist.
+    """
+
+    def __getitem__(self, key: str) -> Iterable[Tuple[str, torch.Tensor]]:
+        try:
+            return super().__getitem__(key)
+        except KeyError as exc:
+            msg = (f"There are no weights with the prefix: {key}. "
+                   f"Available prefixes: {set(self.keys())}")
+            raise KeyError(msg) from exc
+
+
+def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]],
+                   prefix: str) -> Iterable[Tuple[str, torch.Tensor]]:
     """
     Helper function to load weights for inner vLLM models.
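
A quick usage sketch of what the WeightsGroup wrapper above buys (the component names here are hypothetical): a plain dict would raise a bare KeyError naming only the missing key, whereas WeightsGroup also reports which prefixes actually exist.

    groups = WeightsGroup({
        "language_model": iter([]),  # hypothetical per-component weight iterators
        "vision_tower": iter([]),
    })

    groups["language_model"]  # succeeds, returns the stored iterator

    groups["vision_model"]
    # KeyError: There are no weights with the prefix: vision_model.
    #           Available prefixes: {'language_model', 'vision_tower'}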
@@ -30,6 +48,22 @@ def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str):
             yield name, loaded_weight
 
 
+def group_weights_with_prefix(
+    weights: Iterable[Tuple[str, torch.Tensor]]
+) -> Dict[str, Iterable[Tuple[str, torch.Tensor]]]:
+    """
+    Helper function to group weights by their top-level prefix.
+    """
+    init_weights, repeated_weights = itertools.tee(weights, 2)
+    weights_prefix = {name.split(".")[0] for name, _ in init_weights}
+    repeated_weights = itertools.tee(repeated_weights, len(weights_prefix))
+
+    return WeightsGroup({
+        prefix: filter_weights(component, prefix)
+        for component, prefix in zip(repeated_weights, weights_prefix)
+    })
+
+
 def init_vllm_registered_model(
     hf_config: PretrainedConfig,
     cache_config: Optional[CacheConfig],
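
To make the data flow concrete, here is a minimal end-to-end sketch of how the helpers above cooperate when loading a composite (vision + language) checkpoint; the weight names and tensor shapes are made up for illustration:

    import torch

    # A flat checkpoint stream, as a composite model's load_weights()
    # would receive it.
    weights = [
        ("vision_tower.patch_embed.weight", torch.empty(16)),
        ("language_model.embed_tokens.weight", torch.empty(16)),
        ("language_model.lm_head.weight", torch.empty(16)),
    ]

    groups = group_weights_with_prefix(iter(weights))

    # Each group is a lazy iterator of (name, tensor) pairs with the
    # component prefix already stripped, ready to forward to that
    # submodule's own load_weights().
    print([name for name, _ in groups["language_model"]])
    # -> ['embed_tokens.weight', 'lm_head.weight']

The itertools.tee fan-out gives every component its own view of the single-pass weights iterator, so grouping works even when the caller passes a one-shot generator; the trade-off is that tee buffers pending items internally, so the per-component iterators should be consumed reasonably close together.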