[Bugfix] fix composite weight loading and EAGLE weight loading (#9160)
@@ -1,7 +1,7 @@
 import itertools
-from collections import UserDict
-from typing import (Any, Dict, Iterable, List, Literal, Optional, Protocol,
-                    Tuple, Union, overload)
+from dataclasses import dataclass, field
+from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
+                    Protocol, Tuple, Union, overload)
 
 import torch
 import torch.nn as nn
@@ -12,55 +12,184 @@ from vllm.config import (CacheConfig, LoRAConfig, MultiModalConfig,
                          SchedulerConfig)
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.loader import build_model
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models import ModelRegistry
 from vllm.multimodal.base import NestedTensors
 from vllm.sequence import IntermediateTensors
 from vllm.utils import is_pin_memory_available
 
+WeightsMapping = Mapping[str, Optional[str]]
+"""If a key maps to a value of `None`, the corresponding weight is ignored."""
 
 
-class WeightsGroup(UserDict):
+@dataclass
+class WeightsMapper:
+    """Maps the name of each weight if they match the following patterns."""
+
+    orig_to_new_substr: WeightsMapping = field(default_factory=dict)
+    orig_to_new_prefix: WeightsMapping = field(default_factory=dict)
+    orig_to_new_suffix: WeightsMapping = field(default_factory=dict)
+
+    def _map_name(self, key: str) -> Optional[str]:
+        for substr, new_key in self.orig_to_new_substr.items():
+            if substr in key:
+                if new_key is None:
+                    return None
+
+                key = key.replace(substr, new_key, 1)
+
+        for prefix, new_key in self.orig_to_new_prefix.items():
+            if key.startswith(prefix):
+                if new_key is None:
+                    return None
+
+                key = key.replace(prefix, new_key, 1)
+
+        for suffix, new_key in self.orig_to_new_suffix.items():
+            if key.endswith(suffix):
+                if new_key is None:
+                    return None
+
+                key = new_key.join(key.rsplit(suffix, 1))
+
+        return key
+
+    def apply(
+        self, weights: Iterable[Tuple[str, torch.Tensor]]
+    ) -> Iterable[Tuple[str, torch.Tensor]]:
+        return ((out_name, data) for name, data in weights
+                if (out_name := self._map_name(name)) is not None)
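
As a minimal sketch of how the new `WeightsMapper` behaves (annotation only,
not part of the commit; the checkpoint names here are invented):

    mapper = WeightsMapper(
        orig_to_new_prefix={"transformer.": "model.", "lm_head.": None},
        orig_to_new_suffix={".gamma": ".weight"},
    )
    weights = [
        ("transformer.layers.0.norm.gamma", torch.empty(16)),
        ("lm_head.weight", torch.empty(16)),
    ]
    # Yields [("model.layers.0.norm.weight", <tensor>)]; the "lm_head."
    # entry is dropped because its prefix maps to None.
    renamed = list(mapper.apply(weights))
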
+
+
+class AutoWeightsLoader:
     """
-    Wraps grouped weights dictionary for a more informative error message
-    when attempting to access a weight component that does not exist.
+    Helper class to load weights into a :class:`torch.nn.Module`. It is able
+    to automatically detect child modules and parameters while iterating over
+    the weights only once.
+
+    The weight loading logic for individual modules can be overridden
+    by defining a ``load_weights`` method.
+
+    Similarly, the weight loading logic for individual parameters can be
+    overridden by defining a ``weight_loader`` method.
     """
 
-    def __getitem__(self, key: str) -> Iterable[Tuple[str, torch.Tensor]]:
-        try:
-            return super().__getitem__(key)
-        except KeyError as exc:
-            msg = (f"There is no weights named with the prefix: {key}. "
-                   f"Available prefix: {set(self.keys())}")
-            raise KeyError(msg) from exc
+    def __init__(
+        self,
+        module: nn.Module,
+        *,
+        skip_prefixes: Optional[List[str]] = None,
+        ignore_unexpected_prefixes: Optional[List[str]] = None,
+    ) -> None:
+        super().__init__()
+
+        self.module = module
+        self.skip_prefixes = skip_prefixes or []
+        self.ignore_unexpected_prefixes = ignore_unexpected_prefixes or []
 
-def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]],
-                   prefix: str) -> Iterable[Tuple[str, torch.Tensor]]:
-    """
-    Helper function to load weights for inner vLLM models.
+    def _groupby_prefix(
+        self,
+        weights: Iterable[Tuple[str, torch.Tensor]],
+    ) -> Iterable[Tuple[str, Iterable[Tuple[str, torch.Tensor]]]]:
+        weights_by_parts = ((weight_name.split(".", 1), weight_data)
+                            for weight_name, weight_data in weights)
 
-    See also:
-        :ref:`init_vllm_registered_model`
-    """
-    for name, loaded_weight in weights:
-        name = name.split(".")
-        if prefix == name.pop(0):
-            name = ".".join(name)
-            yield name, loaded_weight
+        for prefix, group in itertools.groupby(weights_by_parts,
+                                               key=lambda x: x[0][0]):
+            yield (
+                prefix,
+                # Because maxsplit=1 in weight_name.split(...),
+                # the length of `parts` must either be 1 or 2
+                (("" if len(parts) == 1 else parts[1], weights_data)
+                 for parts, weights_data in group),
+            )
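
For intuition, `_groupby_prefix` splits each dotted name on its first "." and
groups consecutive entries by the leading component; a parameter that belongs
directly to the module comes out with an empty remainder. A sketch with
invented names, calling the private helper on a hypothetical
`AutoWeightsLoader` instance `loader` just for demonstration:

    weights = [
        ("encoder.layer.weight", torch.empty(4)),
        ("encoder.layer.bias", torch.empty(4)),
        ("head.weight", torch.empty(4)),
    ]
    for prefix, group in loader._groupby_prefix(weights):
        print(prefix, [rest for rest, _ in group])
    # encoder ['layer.weight', 'layer.bias']
    # head ['weight']

Note that `itertools.groupby` only merges *consecutive* equal keys, which
holds for checkpoints that store each submodule's weights contiguously.
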
 
+    def _get_qualname(self, prefix: str, rest: str) -> str:
+        if prefix == "":
+            return rest
+        if rest == "":
+            return prefix
 
-def group_weights_with_prefix(
-        weights: Iterable[Tuple[str, torch.Tensor]], ) -> WeightsGroup:
-    """
-    Helper function to group weights with prefix
-    """
-    init_weights, repeated_weights = itertools.tee(weights, 2)
-    weights_prefix = {name.split(".")[0] for name, _ in init_weights}
-    repeated_weights = itertools.tee(repeated_weights, len(weights_prefix))
+        return ".".join((prefix, rest))
 
-    return WeightsGroup({
-        prefix: filter_weights(component, prefix)
-        for component, prefix in zip(repeated_weights, weights_prefix)
-    })
+    def _can_skip(self, qualname: str) -> bool:
+        return any(qualname.startswith(p) for p in self.skip_prefixes)
+
+    def _can_ignore_unexpected(self, qualname: str) -> bool:
+        return any(
+            qualname.startswith(p) for p in self.ignore_unexpected_prefixes)
+
+    def _load_param(
+        self,
+        base_prefix: str,
+        param: nn.Parameter,
+        weights: Iterable[Tuple[str, torch.Tensor]],
+    ) -> None:
+        for weight_name, weight_data in weights:
+            weight_qualname = self._get_qualname(base_prefix, weight_name)
+
+            if self._can_skip(weight_qualname):
+                continue
+
+            if weight_name != "":
+                if not self._can_ignore_unexpected(weight_qualname):
+                    raise ValueError(
+                        f"Attempted to load nested weight '{weight_qualname}' "
+                        f"into a single parameter '{base_prefix}'")
+
+                continue
+
+            weight_loader = getattr(param, "weight_loader",
+                                    default_weight_loader)
+            weight_loader(param, weight_data)
+
+    def _load_module(
+        self,
+        base_prefix: str,
+        module: nn.Module,
+        weights: Iterable[Tuple[str, torch.Tensor]],
+    ) -> None:
+        if isinstance(module, PPMissingLayer):
+            return
+
+        # Avoid infinite recursion since this function is typically
+        # called inside load_weights of the module itself
+        if module != self.module:
+            module_load_weights = getattr(module, "load_weights", None)
+            if callable(module_load_weights):
+                module_load_weights(weights)
+                return
+
+        child_modules = dict(module.named_children())
+        child_params = dict(module.named_parameters(recurse=False))
+
+        for child_prefix, child_weights in self._groupby_prefix(weights):
+            prefix = self._get_qualname(base_prefix, child_prefix)
+
+            if self._can_skip(prefix):
+                continue
+
+            if child_prefix in child_modules:
+                self._load_module(prefix, child_modules[child_prefix],
+                                  child_weights)
+            elif child_prefix in child_params:
+                self._load_param(prefix, child_params[child_prefix],
+                                 child_weights)
+            else:
+                if not self._can_ignore_unexpected(prefix):
+                    msg = f"There is no module or parameter named '{prefix}'"
+                    raise ValueError(msg)
+
+    def load_weights(
+        self,
+        weights: Iterable[Tuple[str, torch.Tensor]],
+        *,
+        mapper: Optional[WeightsMapper] = None,
+    ) -> None:
+        if mapper is not None:
+            weights = mapper.apply(weights)
+
+        self._load_module("", self.module, weights)
 
 
 def init_vllm_registered_model(
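
To sketch end-to-end usage of the loader (annotation only; the toy module and
checkpoint names below are invented, though vLLM models delegate to
`AutoWeightsLoader` in the same pattern):

    import torch
    import torch.nn as nn

    class ToyModel(nn.Module):

        def __init__(self) -> None:
            super().__init__()
            self.language_model = nn.Linear(8, 8)
            self.lm_head = nn.Linear(8, 8)

    model = ToyModel()
    loader = AutoWeightsLoader(
        model,
        # Checkpoint entries under "lm_head." are skipped (e.g. when the
        # output head is tied to the input embeddings).
        skip_prefixes=["lm_head."],
    )
    checkpoint = [
        ("language_model.weight", torch.randn(8, 8)),
        ("language_model.bias", torch.randn(8)),
        ("lm_head.weight", torch.randn(8, 8)),
    ]
    # Each ("name", tensor) pair is dispatched recursively to the matching
    # child module or parameter, falling back to default_weight_loader.
    loader.load_weights(checkpoint)

A `WeightsMapper` can be passed via `load_weights(..., mapper=...)` to rename
checkpoint keys before they are dispatched.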