[Core] Add update_config RPC method (#20095)

Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
This commit is contained in:
22quinn
2025-07-13 17:49:18 -07:00
committed by GitHub
parent 4bbfc36b16
commit 8632e831ba
7 changed files with 97 additions and 9 deletions

View File

@@ -434,16 +434,28 @@ def test_kv_cache_stride_order(monkeypatch, model_runner):
assert all(not kv.is_contiguous() for kv in model_runner.kv_caches)
def test_update_config(model_runner):
    """Exercise ``update_config`` with a valid and an invalid override."""
    # A nested override for an existing config must show up on the runner.
    dummy_override = {"load_config": {"load_format": "dummy"}}
    model_runner.update_config(dummy_override)
    assert model_runner.load_config.load_format == "dummy"

    # A top-level key that names no existing config is expected to raise
    # an AssertionError.
    unknown_override = {"do_not_exist_config": "dummy"}
    with pytest.raises(AssertionError):
        model_runner.update_config(unknown_override)
def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2):
    """Compare one-shot weight loading against dummy-then-real loading.

    ``model_runner`` loads the model and its weights in a single step, while
    ``model_runner_2`` first loads with the "dummy" load format and then
    loads again after the original load format has been restored. The two
    state dicts must differ after the first load and match after the second.

    Every load-format change goes through ``update_config`` only; the config
    attribute is never written directly, so the test fails if
    ``update_config`` does not actually apply the override.
    """
    # In this test, model_runner loads model + weights in one go, while
    # model_runner_2 loads dummy weights first then load real weights inplace
    model_runner.load_model()

    # Remember the configured format so it can be restored for the second load.
    original_load_format = model_runner_2.load_config.load_format
    model_runner_2.update_config({"load_config": {"load_format": "dummy"}})
    model_runner_2.load_model()  # Initial model loading with dummy weights
    assert str(model_runner.get_model().state_dict()) != str(
        model_runner_2.get_model().state_dict())

    model_runner_2.update_config(
        {"load_config": {
            "load_format": original_load_format
        }})
    model_runner_2.load_model()  # Load real weights inplace
    assert str(model_runner.get_model().state_dict()) == str(
        model_runner_2.get_model().state_dict())