[Bugfix]: serialize config by value for --trust-remote-code (#6751)

Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
Travis Johnson
2024-10-21 20:46:24 -06:00
committed by GitHub
parent 76a5e13270
commit b729901139
4 changed files with 103 additions and 28 deletions

View File

@@ -232,6 +232,68 @@ def get_config(
return config
def maybe_register_config_serialize_by_value(trust_remote_code: bool) -> None:
"""Try to register HF model configuration class to serialize by value
With trust_remote_code, the config class is typically an instance of a
custom class imported from the HF modules cache. The class will not be
importable in spawned workers by default (and won't exist at all on
other nodes), which breaks serialization of the config.
In this function we tell the cloudpickle serialization library to pass
instances of these generated classes by value instead of by reference,
i.e. the class definition is serialized along with its data so that the
class module does not need to be importable on the receiving end. This
registration only works if the modules cache has already been
initialized.
See: https://github.com/cloudpipe/cloudpickle?tab=readme-ov-file#overriding-pickles-serialization-mechanism-for-importable-constructs
"""
if not trust_remote_code:
return
try:
import transformers_modules
except ImportError:
logger.debug("Could not import transformers_modules used for remote"
" code. If remote code is not needed remove"
" `--trust-remote-code`.")
return
try:
import cloudpickle
cloudpickle.register_pickle_by_value(transformers_modules)
# ray vendors its own version of cloudpickle
from vllm.executor.ray_utils import ray
if ray:
ray.cloudpickle.register_pickle_by_value(transformers_modules)
# multiprocessing uses pickle to serialize arguments when using spawn
# Here we get pickle to use cloudpickle to serialize ModelConfig objects
# that contain instances of the custom config class to avoid
# serialization problems if the generated module (and model) has a `.`
# in its name
import multiprocessing
import pickle
from vllm.config import ModelConfig
def _reduce_modelconfig(mc: ModelConfig):
return (pickle.loads, (cloudpickle.dumps(mc), ))
multiprocessing.reducer.register(ModelConfig, _reduce_modelconfig)
except Exception as e:
logger.warning(
"Unable to register remote classes used by"
" trust_remote_code with by-value serialization. This may"
" lead to a later error. If remote code is not needed"
" remove `--trust-remote-code`",
exc_info=e)
def load_params_config(model, revision) -> PretrainedConfig:
# This function loads a params.json config which
# should be used when loading models in mistral format