[Bugfix] Fix PP for ChatGLM and Molmo (#9422)
This commit is contained in:
@@ -79,6 +79,9 @@ class AutoWeightsLoader:
|
||||
|
||||
Similarly, the weight loading logic for individual parameters can be
|
||||
overridden by defining a ``weight_loader`` method.
|
||||
|
||||
Detailed weight loading information can be viewed by setting the
|
||||
environment variable ``VLLM_LOGGING_LEVEL=DEBUG``.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -136,20 +139,27 @@ class AutoWeightsLoader:
|
||||
weight_qualname = self._get_qualname(base_prefix, weight_name)
|
||||
|
||||
if self._can_skip(weight_qualname):
|
||||
logger.debug("Skipping weight %s", weight_qualname)
|
||||
|
||||
continue
|
||||
|
||||
if weight_name != "":
|
||||
if not self._can_ignore_unexpected(weight_qualname):
|
||||
raise ValueError(
|
||||
f"Attempted to load nested weight '{weight_qualname}' "
|
||||
f"into a single parameter '{base_prefix}'")
|
||||
if self._can_ignore_unexpected(weight_qualname):
|
||||
logger.debug("Ignoring weight %s", weight_qualname)
|
||||
|
||||
continue
|
||||
continue
|
||||
|
||||
raise ValueError(
|
||||
f"Attempted to load nested weight '{weight_qualname}' "
|
||||
f"into a single parameter '{base_prefix}'")
|
||||
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, weight_data)
|
||||
|
||||
logger.debug("Loaded weight %s with shape %s", weight_qualname,
|
||||
param.shape)
|
||||
|
||||
yield weight_qualname
|
||||
|
||||
def _load_module(
|
||||
@@ -175,21 +185,41 @@ class AutoWeightsLoader:
|
||||
for child_prefix, child_weights in self._groupby_prefix(weights):
|
||||
prefix = self._get_qualname(base_prefix, child_prefix)
|
||||
|
||||
if self._can_skip(prefix):
|
||||
continue
|
||||
|
||||
if child_prefix in child_modules:
|
||||
if self._can_skip(prefix + "."):
|
||||
logger.debug("Skipping module %s", prefix)
|
||||
|
||||
continue
|
||||
|
||||
yield from self._load_module(prefix,
|
||||
child_modules[child_prefix],
|
||||
child_weights)
|
||||
elif child_prefix in child_params:
|
||||
if self._can_skip(prefix):
|
||||
logger.debug("Skipping param %s", prefix)
|
||||
|
||||
continue
|
||||
|
||||
yield from self._load_param(prefix, child_params[child_prefix],
|
||||
child_weights)
|
||||
else:
|
||||
if not self._can_ignore_unexpected(prefix):
|
||||
msg = (f"There is no module or parameter named '{prefix}' "
|
||||
f"in {type(self.module).__name__}")
|
||||
raise ValueError(msg)
|
||||
can_skip_module = self._can_skip(prefix + ".")
|
||||
can_skip_param = self._can_skip(prefix)
|
||||
if can_skip_module or can_skip_param:
|
||||
logger.debug("Skipping missing %s", prefix)
|
||||
|
||||
continue
|
||||
|
||||
can_ignore_module = self._can_ignore_unexpected(prefix + ".")
|
||||
can_ignore_param = self._can_ignore_unexpected(prefix)
|
||||
if can_ignore_module or can_ignore_param:
|
||||
logger.debug("Ignoring missing %s", prefix)
|
||||
|
||||
continue
|
||||
|
||||
msg = (f"There is no module or parameter named '{prefix}' "
|
||||
f"in {type(self.module).__name__}")
|
||||
raise ValueError(msg)
|
||||
|
||||
def load_weights(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user