Enable safetensors loading for all models (#974)
This commit is contained in:
@@ -24,9 +24,16 @@ class ModelConfig:
|
||||
downloading the model and tokenizer.
|
||||
download_dir: Directory to download and load the weights, default to the
|
||||
default cache directory of huggingface.
|
||||
use_np_weights: Save a numpy copy of model weights for faster loading.
|
||||
This can increase the disk usage by up to 2x.
|
||||
use_dummy_weights: Use dummy values for model weights (for profiling).
|
||||
load_format: The format of the model weights to load:
|
||||
"auto" will try to load the weights in the safetensors format and
|
||||
fall back to the pytorch bin format if safetensors format is
|
||||
not available.
|
||||
"pt" will load the weights in the pytorch bin format.
|
||||
"safetensors" will load the weights in the safetensors format.
|
||||
"npcache" will load the weights in pytorch format and store
|
||||
a numpy cache to speed up the loading.
|
||||
"dummy" will initialize the weights with random values, which is
|
||||
mainly for profiling.
|
||||
dtype: Data type for model weights and activations. The "auto" option
|
||||
will use FP16 precision for FP32 and FP16 models, and BF16 precision
|
||||
for BF16 models.
|
||||
@@ -40,8 +47,7 @@ class ModelConfig:
|
||||
tokenizer_mode: str,
|
||||
trust_remote_code: bool,
|
||||
download_dir: Optional[str],
|
||||
use_np_weights: bool,
|
||||
use_dummy_weights: bool,
|
||||
load_format: str,
|
||||
dtype: str,
|
||||
seed: int,
|
||||
) -> None:
|
||||
@@ -50,14 +56,24 @@ class ModelConfig:
|
||||
self.tokenizer_mode = tokenizer_mode
|
||||
self.trust_remote_code = trust_remote_code
|
||||
self.download_dir = download_dir
|
||||
self.use_np_weights = use_np_weights
|
||||
self.use_dummy_weights = use_dummy_weights
|
||||
self.load_format = load_format
|
||||
self.seed = seed
|
||||
|
||||
self.hf_config = get_config(model, trust_remote_code)
|
||||
self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
|
||||
self._verify_load_format()
|
||||
self._verify_tokenizer_mode()
|
||||
|
||||
def _verify_load_format(self) -> None:
|
||||
load_format = self.load_format.lower()
|
||||
if load_format not in [
|
||||
"auto", "pt", "safetensors", "npcache", "dummy"
|
||||
]:
|
||||
raise ValueError(
|
||||
f"Unknown load format: {self.load_format}. Must be one of "
|
||||
"'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
|
||||
self.load_format = load_format
|
||||
|
||||
def _verify_tokenizer_mode(self) -> None:
|
||||
tokenizer_mode = self.tokenizer_mode.lower()
|
||||
if tokenizer_mode not in ["auto", "slow"]:
|
||||
|
||||
Reference in New Issue
Block a user