diff --git a/examples/offline_inference/save_sharded_state.py b/examples/offline_inference/save_sharded_state.py index d6b8b7e68..41d7a3492 100644 --- a/examples/offline_inference/save_sharded_state.py +++ b/examples/offline_inference/save_sharded_state.py @@ -47,7 +47,7 @@ def parse_args(): ) parser.add_argument( "--max-file-size", - type=str, + type=int, default=5 * 1024**3, help="max size (in bytes) of each safetensors file", )