[Core] Add reload_weights RPC method (#20096)

Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
This commit is contained in:
22quinn
2025-07-23 14:24:52 -07:00
committed by GitHub
parent 14bf19e39f
commit 5c9b807b34
5 changed files with 51 additions and 34 deletions

View File

@@ -460,11 +460,16 @@ def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2):
{"load_config": {
"load_format": original_load_format
}})
model_runner_2.load_model() # Load real weights inplace
model_runner_2.reload_weights() # Load real weights inplace
assert str(model_runner.get_model().state_dict()) == str(
model_runner_2.get_model().state_dict())
def test_reload_weights_before_load_model(model_runner):
with pytest.raises(AssertionError):
model_runner.reload_weights()
def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
torch.set_default_dtype(torch.float16)
layer_0 = "model.layers.0.self_attn.attn"