diff --git a/tests/runai_model_streamer_test/test_weight_utils.py b/tests/runai_model_streamer_test/test_weight_utils.py index 4afa76c51..06e506c35 100644 --- a/tests/runai_model_streamer_test/test_weight_utils.py +++ b/tests/runai_model_streamer_test/test_weight_utils.py @@ -23,10 +23,11 @@ def test_runai_model_loader(): runai_model_streamer_tensors = {} hf_safetensors_tensors = {} - for name, tensor in runai_safetensors_weights_iterator(safetensors): + for name, tensor in runai_safetensors_weights_iterator( + safetensors, True): runai_model_streamer_tensors[name] = tensor - for name, tensor in safetensors_weights_iterator(safetensors): + for name, tensor in safetensors_weights_iterator(safetensors, True): hf_safetensors_tensors[name] = tensor assert len(runai_model_streamer_tensors) == len(hf_safetensors_tensors)