[Bugfix]: serialize config by value for --trust-remote-code (#6751)
Signed-off-by: Travis Johnson <tsjohnso@us.ibm.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
@@ -232,6 +232,68 @@ def get_config(
|
||||
return config
|
||||
|
||||
|
||||
def maybe_register_config_serialize_by_value(trust_remote_code: bool) -> None:
|
||||
"""Try to register HF model configuration class to serialize by value
|
||||
|
||||
With trust_remote_code, the config class is typically an instance of a
|
||||
custom class imported from the HF modules cache. The class will not be
|
||||
importable in spawned workers by default (and won't exist at all on
|
||||
other nodes), which breaks serialization of the config.
|
||||
|
||||
In this function we tell the cloudpickle serialization library to pass
|
||||
instances of these generated classes by value instead of by reference,
|
||||
i.e. the class definition is serialized along with its data so that the
|
||||
class module does not need to be importable on the receiving end. This
|
||||
registration only works if the modules cache has already been
|
||||
initialized.
|
||||
|
||||
|
||||
See: https://github.com/cloudpipe/cloudpickle?tab=readme-ov-file#overriding-pickles-serialization-mechanism-for-importable-constructs
|
||||
"""
|
||||
if not trust_remote_code:
|
||||
return
|
||||
|
||||
try:
|
||||
import transformers_modules
|
||||
except ImportError:
|
||||
logger.debug("Could not import transformers_modules used for remote"
|
||||
" code. If remote code is not needed remove"
|
||||
" `--trust-remote-code`.")
|
||||
return
|
||||
|
||||
try:
|
||||
import cloudpickle
|
||||
cloudpickle.register_pickle_by_value(transformers_modules)
|
||||
|
||||
# ray vendors its own version of cloudpickle
|
||||
from vllm.executor.ray_utils import ray
|
||||
if ray:
|
||||
ray.cloudpickle.register_pickle_by_value(transformers_modules)
|
||||
|
||||
# multiprocessing uses pickle to serialize arguments when using spawn
|
||||
# Here we get pickle to use cloudpickle to serialize ModelConfig objects
|
||||
# that contain instances of the custom config class to avoid
|
||||
# serialization problems if the generated module (and model) has a `.`
|
||||
# in its name
|
||||
import multiprocessing
|
||||
import pickle
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
|
||||
def _reduce_modelconfig(mc: ModelConfig):
|
||||
return (pickle.loads, (cloudpickle.dumps(mc), ))
|
||||
|
||||
multiprocessing.reducer.register(ModelConfig, _reduce_modelconfig)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Unable to register remote classes used by"
|
||||
" trust_remote_code with by-value serialization. This may"
|
||||
" lead to a later error. If remote code is not needed"
|
||||
" remove `--trust-remote-code`",
|
||||
exc_info=e)
|
||||
|
||||
|
||||
def load_params_config(model, revision) -> PretrainedConfig:
|
||||
# This function loads a params.json config which
|
||||
# should be used when loading models in mistral format
|
||||
|
||||
Reference in New Issue
Block a user