[Bugfix] Fix PP for ChatGLM and Molmo (#9422)

This commit is contained in:
Cyrus Leung
2024-10-24 14:12:05 +08:00
committed by GitHub
parent 056a68c7db
commit 836e8ef6ee
7 changed files with 195 additions and 122 deletions

View File

@@ -79,6 +79,9 @@ class AutoWeightsLoader:
Similarly, the weight loading logic for individual parameters can be
overridden by defining a ``weight_loader`` method.
Detailed weight loading information can be viewed by setting the
environment variable ``VLLM_LOGGING_LEVEL=DEBUG``.
"""
def __init__(
@@ -136,20 +139,27 @@ class AutoWeightsLoader:
weight_qualname = self._get_qualname(base_prefix, weight_name)
if self._can_skip(weight_qualname):
logger.debug("Skipping weight %s", weight_qualname)
continue
if weight_name != "":
if not self._can_ignore_unexpected(weight_qualname):
raise ValueError(
f"Attempted to load nested weight '{weight_qualname}' "
f"into a single parameter '{base_prefix}'")
if self._can_ignore_unexpected(weight_qualname):
logger.debug("Ignoring weight %s", weight_qualname)
continue
continue
raise ValueError(
f"Attempted to load nested weight '{weight_qualname}' "
f"into a single parameter '{base_prefix}'")
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, weight_data)
logger.debug("Loaded weight %s with shape %s", weight_qualname,
param.shape)
yield weight_qualname
def _load_module(
@@ -175,21 +185,41 @@ class AutoWeightsLoader:
for child_prefix, child_weights in self._groupby_prefix(weights):
prefix = self._get_qualname(base_prefix, child_prefix)
if self._can_skip(prefix):
continue
if child_prefix in child_modules:
if self._can_skip(prefix + "."):
logger.debug("Skipping module %s", prefix)
continue
yield from self._load_module(prefix,
child_modules[child_prefix],
child_weights)
elif child_prefix in child_params:
if self._can_skip(prefix):
logger.debug("Skipping param %s", prefix)
continue
yield from self._load_param(prefix, child_params[child_prefix],
child_weights)
else:
if not self._can_ignore_unexpected(prefix):
msg = (f"There is no module or parameter named '{prefix}' "
f"in {type(self.module).__name__}")
raise ValueError(msg)
can_skip_module = self._can_skip(prefix + ".")
can_skip_param = self._can_skip(prefix)
if can_skip_module or can_skip_param:
logger.debug("Skipping missing %s", prefix)
continue
can_ignore_module = self._can_ignore_unexpected(prefix + ".")
can_ignore_param = self._can_ignore_unexpected(prefix)
if can_ignore_module or can_ignore_param:
logger.debug("Ignoring missing %s", prefix)
continue
msg = (f"There is no module or parameter named '{prefix}' "
f"in {type(self.module).__name__}")
raise ValueError(msg)
def load_weights(
self,