[Core] Add update_config RPC method (#20095)
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
This commit is contained in:
@@ -434,16 +434,28 @@ def test_kv_cache_stride_order(monkeypatch, model_runner):
|
||||
assert all(not kv.is_contiguous() for kv in model_runner.kv_caches)
|
||||
|
||||
|
||||
def test_update_config(model_runner):
|
||||
# Simple update
|
||||
model_runner.update_config({"load_config": {"load_format": "dummy"}})
|
||||
assert model_runner.load_config.load_format == "dummy"
|
||||
# Raise error on non-existing config
|
||||
with pytest.raises(AssertionError):
|
||||
model_runner.update_config({"do_not_exist_config": "dummy"})
|
||||
|
||||
|
||||
def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2):
|
||||
# In this test, model_runner loads model + weights in one go, while
|
||||
# model_runner_2 loads dummy weights first then load real weights inplace
|
||||
model_runner.load_model()
|
||||
original_load_format = model_runner_2.load_config.load_format
|
||||
model_runner_2.load_config.load_format = "dummy"
|
||||
model_runner_2.update_config({"load_config": {"load_format": "dummy"}})
|
||||
model_runner_2.load_model() # Initial model loading with dummy weights
|
||||
assert str(model_runner.get_model().state_dict()) != str(
|
||||
model_runner_2.get_model().state_dict())
|
||||
model_runner_2.load_config.load_format = original_load_format
|
||||
model_runner_2.update_config(
|
||||
{"load_config": {
|
||||
"load_format": original_load_format
|
||||
}})
|
||||
model_runner_2.load_model() # Load real weights inplace
|
||||
assert str(model_runner.get_model().state_dict()) == str(
|
||||
model_runner_2.get_model().state_dict())
|
||||
|
||||
Reference in New Issue
Block a user