[Bugfix] fix composite weight loading and EAGLE weight loading (#9160)

Cyrus Leung
2024-10-09 15:36:55 +08:00
committed by GitHub
parent 0b5b5d767e
commit 8bfaa4e31e
15 changed files with 241 additions and 361 deletions


@@ -1,7 +1,7 @@
 import itertools
-from collections import UserDict
-from typing import (Any, Dict, Iterable, List, Literal, Optional, Protocol,
-                    Tuple, Union, overload)
+from dataclasses import dataclass, field
+from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
+                    Protocol, Tuple, Union, overload)
 
 import torch
 import torch.nn as nn
@@ -12,55 +12,184 @@ from vllm.config import (CacheConfig, LoRAConfig, MultiModalConfig,
                          SchedulerConfig)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.loader import build_model
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models import ModelRegistry
 from vllm.multimodal.base import NestedTensors
 from vllm.sequence import IntermediateTensors
 from vllm.utils import is_pin_memory_available
 
 
-class WeightsGroup(UserDict):
-    """
-    Wraps grouped weights dictionary for a more informative error message
-    when attempting to access a weight component that does not exist.
-    """
-
-    def __getitem__(self, key: str) -> Iterable[Tuple[str, torch.Tensor]]:
-        try:
-            return super().__getitem__(key)
-        except KeyError as exc:
-            msg = (f"There is no weights named with the prefix: {key}. "
-                   f"Available prefix: {set(self.keys())}")
-            raise KeyError(msg) from exc
-
-
-def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]],
-                   prefix: str) -> Iterable[Tuple[str, torch.Tensor]]:
-    """
-    Helper function to load weights for inner vLLM models.
-
-    See also:
-        :ref:`init_vllm_registered_model`
-    """
-    for name, loaded_weight in weights:
-        name = name.split(".")
-        if prefix == name.pop(0):
-            name = ".".join(name)
-            yield name, loaded_weight
-
-
-def group_weights_with_prefix(
-        weights: Iterable[Tuple[str, torch.Tensor]], ) -> WeightsGroup:
-    """
-    Helper function to group weights with prefix
-    """
-    init_weights, repeated_weights = itertools.tee(weights, 2)
-    weights_prefix = {name.split(".")[0] for name, _ in init_weights}
-    repeated_weights = itertools.tee(repeated_weights, len(weights_prefix))
-
-    return WeightsGroup({
-        prefix: filter_weights(component, prefix)
-        for component, prefix in zip(repeated_weights, weights_prefix)
-    })
+WeightsMapping = Mapping[str, Optional[str]]
+"""If a key maps to a value of `None`, the corresponding weight is ignored."""
+
+
+@dataclass
+class WeightsMapper:
+    """Maps the name of each weight if they match the following patterns."""
+
+    orig_to_new_substr: WeightsMapping = field(default_factory=dict)
+    orig_to_new_prefix: WeightsMapping = field(default_factory=dict)
+    orig_to_new_suffix: WeightsMapping = field(default_factory=dict)
+
+    def _map_name(self, key: str) -> Optional[str]:
+        for substr, new_key in self.orig_to_new_substr.items():
+            if substr in key:
+                if new_key is None:
+                    return None
+
+                key = key.replace(substr, new_key, 1)
+
+        for prefix, new_key in self.orig_to_new_prefix.items():
+            if key.startswith(prefix):
+                if new_key is None:
+                    return None
+
+                key = key.replace(prefix, new_key, 1)
+
+        for suffix, new_key in self.orig_to_new_suffix.items():
+            if key.endswith(suffix):
+                if new_key is None:
+                    return None
+
+                key = new_key.join(key.rsplit(suffix, 1))
+
+        return key
+
+    def apply(
+        self, weights: Iterable[Tuple[str, torch.Tensor]]
+    ) -> Iterable[Tuple[str, torch.Tensor]]:
+        return ((out_name, data) for name, data in weights
+                if (out_name := self._map_name(name)) is not None)
+
+
+class AutoWeightsLoader:
+    """
+    Helper class to load weights into a :class:`torch.nn.Module`. It is able
+    to automatically detect child modules and parameters while iterating over
+    the weights only once.
+
+    The weight loading logic for individual modules can be overridden
+    by defining a ``load_weights`` method.
+
+    Similarly, the weight loading logic for individual parameters can be
+    overridden by defining a ``weight_loader`` method.
+    """
+
+    def __init__(
+        self,
+        module: nn.Module,
+        *,
+        skip_prefixes: Optional[List[str]] = None,
+        ignore_unexpected_prefixes: Optional[List[str]] = None,
+    ) -> None:
+        super().__init__()
+
+        self.module = module
+        self.skip_prefixes = skip_prefixes or []
+        self.ignore_unexpected_prefixes = ignore_unexpected_prefixes or []
+
+    def _groupby_prefix(
+        self,
+        weights: Iterable[Tuple[str, torch.Tensor]],
+    ) -> Iterable[Tuple[str, Iterable[Tuple[str, torch.Tensor]]]]:
+        weights_by_parts = ((weight_name.split(".", 1), weight_data)
+                            for weight_name, weight_data in weights)
+
+        for prefix, group in itertools.groupby(weights_by_parts,
+                                               key=lambda x: x[0][0]):
+            yield (
+                prefix,
+                # Because maxsplit=1 in weight_name.split(...),
+                # the length of `parts` must either be 1 or 2
+                (("" if len(parts) == 1 else parts[1], weights_data)
+                 for parts, weights_data in group),
+            )
+
+    def _get_qualname(self, prefix: str, rest: str) -> str:
+        if prefix == "":
+            return rest
+        if rest == "":
+            return prefix
+
+        return ".".join((prefix, rest))
+
+    def _can_skip(self, qualname: str) -> bool:
+        return any(qualname.startswith(p) for p in self.skip_prefixes)
+
+    def _can_ignore_unexpected(self, qualname: str) -> bool:
+        return any(
+            qualname.startswith(p) for p in self.ignore_unexpected_prefixes)
+
+    def _load_param(
+        self,
+        base_prefix: str,
+        param: nn.Parameter,
+        weights: Iterable[Tuple[str, torch.Tensor]],
+    ) -> None:
+        for weight_name, weight_data in weights:
+            weight_qualname = self._get_qualname(base_prefix, weight_name)
+
+            if self._can_skip(weight_qualname):
+                continue
+
+            if weight_name != "":
+                if not self._can_ignore_unexpected(weight_qualname):
+                    raise ValueError(
+                        f"Attempted to load nested weight '{weight_qualname}' "
+                        f"into a single parameter '{base_prefix}'")
+
+                continue
+
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, weight_data)
+
+    def _load_module(
+        self,
+        base_prefix: str,
+        module: nn.Module,
+        weights: Iterable[Tuple[str, torch.Tensor]],
+    ) -> None:
+        if isinstance(module, PPMissingLayer):
+            return
+
+        # Avoid infinite recursion since this function is typically
+        # called inside load_weights of the module itself
+        if module != self.module:
+            module_load_weights = getattr(module, "load_weights", None)
+            if callable(module_load_weights):
+                module_load_weights(weights)
+                return
+
+        child_modules = dict(module.named_children())
+        child_params = dict(module.named_parameters(recurse=False))
+
+        for child_prefix, child_weights in self._groupby_prefix(weights):
+            prefix = self._get_qualname(base_prefix, child_prefix)
+
+            if self._can_skip(prefix):
+                continue
+
+            if child_prefix in child_modules:
+                self._load_module(prefix, child_modules[child_prefix],
+                                  child_weights)
+            elif child_prefix in child_params:
+                self._load_param(prefix, child_params[child_prefix],
+                                 child_weights)
+            else:
+                if not self._can_ignore_unexpected(prefix):
+                    msg = f"There is no module or parameter named '{prefix}'"
+                    raise ValueError(msg)
+
+    def load_weights(
+        self,
+        weights: Iterable[Tuple[str, torch.Tensor]],
+        *,
+        mapper: Optional[WeightsMapper] = None,
+    ) -> None:
+        if mapper is not None:
+            weights = mapper.apply(weights)
+
+        self._load_module("", self.module, weights)
 
 
 def init_vllm_registered_model(
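
Usage sketch for the WeightsMapper added above (not part of this commit; the
import path is an assumption, since the diff view does not show the file name,
and the checkpoint names are hypothetical):

import torch

from vllm.model_executor.models.utils import WeightsMapper  # assumed path

# Hypothetical checkpoint entries.
weights = [
    ("transformer.h.0.attn.weight", torch.zeros(1)),
    ("transformer.lm_head.weight", torch.zeros(1)),
    ("transformer.h.0.rotary_emb.inv_freq", torch.zeros(1)),
]

mapper = WeightsMapper(
    # Rewrite the leading "transformer." to "model."
    orig_to_new_prefix={"transformer.": "model."},
    # Mapping a pattern to None drops the matching weight entirely.
    orig_to_new_suffix={".rotary_emb.inv_freq": None},
)

print([name for name, _ in mapper.apply(weights)])
# ['model.h.0.attn.weight', 'model.lm_head.weight']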
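
A note on the single-pass traversal: _groupby_prefix relies on
itertools.groupby, which only merges consecutive items with equal keys, so
weights sharing a top-level prefix are assumed to arrive adjacently, as they
do when a checkpoint is iterated in module order. A standalone sketch of the
same grouping idea, independent of vLLM:

import itertools

names = [
    "embed.weight",
    "layers.0.attn.weight",
    "layers.0.mlp.weight",
    "lm_head.weight",
]

# Split once on the first dot and group by the leading component,
# mirroring AutoWeightsLoader._groupby_prefix.
parts = (name.split(".", 1) for name in names)
for prefix, group in itertools.groupby(parts, key=lambda p: p[0]):
    print(prefix, ["" if len(p) == 1 else p[1] for p in group])
# embed ['weight']
# layers ['0.attn.weight', '0.mlp.weight']
# lm_head ['weight']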
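
And a sketch of how a model might drive AutoWeightsLoader from its own
load_weights method (again not part of this commit: the model, its weight
names, and the import path are all hypothetical):

import torch
import torch.nn as nn

from vllm.model_executor.models.utils import (  # assumed path
    AutoWeightsLoader, WeightsMapper)


class ToyModel(nn.Module):
    # Hypothetical two-submodule model, used only to exercise the loader.

    def __init__(self) -> None:
        super().__init__()
        self.embed = nn.Embedding(8, 4)
        self.proj = nn.Linear(4, 4)

    def load_weights(self, weights):
        loader = AutoWeightsLoader(
            self,
            # Top-level entries under this prefix are dropped instead of
            # raising "There is no module or parameter named ...".
            ignore_unexpected_prefixes=["rotary_emb"],
        )
        loader.load_weights(
            weights,
            # Rename checkpoint prefixes to this module's attribute names.
            mapper=WeightsMapper(
                orig_to_new_prefix={"tok_embeddings.": "embed."}),
        )


model = ToyModel()
model.load_weights([
    ("tok_embeddings.weight", torch.randn(8, 4)),
    ("proj.weight", torch.randn(4, 4)),
    ("proj.bias", torch.randn(4)),
    ("rotary_emb.inv_freq", torch.randn(2)),  # ignored, never loaded
])

Because _load_module only short-circuits into a custom load_weights for
submodules (module != self.module), constructing the loader on self inside
load_weights does not recurse back into this method.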