Add blacklist in model checkpoint (#1325)
This commit is contained in:
@@ -144,8 +144,18 @@ def prepare_hf_model_weights(
|
|||||||
for pattern in allow_patterns:
|
for pattern in allow_patterns:
|
||||||
hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
|
hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
|
||||||
if not use_safetensors:
|
if not use_safetensors:
|
||||||
|
# Exclude files that are not needed for inference.
|
||||||
|
# https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
|
||||||
|
blacklist = [
|
||||||
|
"training_args.bin",
|
||||||
|
"optimizer.bin",
|
||||||
|
"optimizer.pt",
|
||||||
|
"scheduler.pt",
|
||||||
|
"scaler.pt",
|
||||||
|
]
|
||||||
hf_weights_files = [
|
hf_weights_files = [
|
||||||
x for x in hf_weights_files if not x.endswith("training_args.bin")
|
f for f in hf_weights_files
|
||||||
|
if not any(f.endswith(x) for x in blacklist)
|
||||||
]
|
]
|
||||||
|
|
||||||
if len(hf_weights_files) == 0 and use_safetensors and fall_back_to_pt:
|
if len(hf_weights_files) == 0 and use_safetensors and fall_back_to_pt:
|
||||||
|
|||||||
Reference in New Issue
Block a user