[Core] Add reload_weights RPC method (#20096)
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
This commit is contained in:
@@ -460,11 +460,16 @@ def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2):
|
||||
{"load_config": {
|
||||
"load_format": original_load_format
|
||||
}})
|
||||
model_runner_2.load_model() # Load real weights inplace
|
||||
model_runner_2.reload_weights() # Load real weights inplace
|
||||
assert str(model_runner.get_model().state_dict()) == str(
|
||||
model_runner_2.get_model().state_dict())
|
||||
|
||||
|
||||
def test_reload_weights_before_load_model(model_runner):
|
||||
with pytest.raises(AssertionError):
|
||||
model_runner.reload_weights()
|
||||
|
||||
|
||||
def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
|
||||
torch.set_default_dtype(torch.float16)
|
||||
layer_0 = "model.layers.0.self_attn.attn"
|
||||
|
||||
Reference in New Issue
Block a user