diff --git a/tests/entrypoints/weight_transfer/test_weight_transfer_llm.py b/tests/entrypoints/weight_transfer/test_weight_transfer_llm.py index cd13aca7e..255bca444 100644 --- a/tests/entrypoints/weight_transfer/test_weight_transfer_llm.py +++ b/tests/entrypoints/weight_transfer/test_weight_transfer_llm.py @@ -90,6 +90,10 @@ class MockWeightTransferEngine(WeightTransferEngine[MockInitInfo, MockUpdateInfo def shutdown(self) -> None: MockWeightTransferEngine.shutdown_called = True + def trainer_send_weights(self, *args, **kwargs): + """Mock method to simulate trainer sending weights.""" + pass + def mock_create_engine(config, parallel_config): """Mock factory function that returns our mock engine."""