[Bugfix] Refactor composite weight loading logic (#8656)
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
import itertools
|
||||
from collections import UserDict
|
||||
from typing import (Dict, Iterable, List, Literal, Optional, Protocol, Tuple,
|
||||
Union, overload)
|
||||
|
||||
@@ -16,7 +18,23 @@ from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import is_pin_memory_available
|
||||
|
||||
|
||||
def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str):
|
||||
class WeightsGroup(UserDict):
    """
    Wraps grouped weights dictionary for a more informative error message
    when attempting to access a weight component that does not exist.
    """

    def __getitem__(self, key: str) -> Iterable[Tuple[str, torch.Tensor]]:
        """Return the grouped weights stored under ``key``.

        Raises:
            KeyError: if ``key`` is not a known weight prefix; the message
                lists the available prefixes so the failure is actionable.
        """
        try:
            return super().__getitem__(key)
        except KeyError as exc:
            # Fix vs. original: the return annotation was ``-> int`` even
            # though the dict maps prefixes to weight iterables (see how
            # group_weights_with_prefix populates it).
            msg = (f"There is no weights named with the prefix: {key}. "
                   f"Available prefix: {set(self.keys())}")
            raise KeyError(msg) from exc
|
||||
|
||||
|
||||
def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]],
|
||||
prefix: str) -> Iterable[Tuple[str, torch.Tensor]]:
|
||||
"""
|
||||
Helper function to load weights for inner vLLM models.
|
||||
|
||||
@@ -30,6 +48,22 @@ def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], prefix: str):
|
||||
yield name, loaded_weight
|
||||
|
||||
|
||||
def group_weights_with_prefix(
    weights: Iterable[Tuple[str, torch.Tensor]]
) -> Dict[str, Iterable[Tuple[str, torch.Tensor]]]:
    """
    Helper function to group weights with prefix.

    Splits an iterable of ``(name, tensor)`` pairs by the first dotted
    component of each name and maps every such prefix to a filtered
    weights iterator for the matching sub-module.
    """
    # First tee'd pass discovers the prefixes; the remainder feeds one
    # independent copy of the stream per prefix.
    scan, remainder = itertools.tee(weights, 2)
    prefixes = {name.split(".")[0] for name, _ in scan}
    streams = itertools.tee(remainder, len(prefixes))

    # Each tee copy is identical, so pairing them with the (unordered)
    # prefix set in any order is safe.
    grouped = {
        prefix: filter_weights(stream, prefix)
        for stream, prefix in zip(streams, prefixes)
    }
    return WeightsGroup(grouped)
|
||||
|
||||
|
||||
def init_vllm_registered_model(
|
||||
hf_config: PretrainedConfig,
|
||||
cache_config: Optional[CacheConfig],
|
||||
|
||||
Reference in New Issue
Block a user