[QeRL] Fix online quantized reloading (#38442)

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
This commit is contained in:
Kyle Sayers
2026-03-29 16:56:41 -04:00
committed by GitHub
parent 995dea1354
commit d28d86e8a3
9 changed files with 104 additions and 62 deletions

View File

@@ -38,7 +38,10 @@ def test_move_metatensors():
def test_reload_lifecycle():
layer = torch.nn.Linear(2, 3)
info = LayerReloadingInfo(restore_metadata=capture_layer_to_meta(layer))
info = LayerReloadingInfo(
restore_metadata=capture_layer_to_meta(layer),
restore_device=torch.device("cpu"),
)
restore_layer_on_meta(layer, info)
for name, tensor in get_layer_tensors(layer).items():
@@ -48,7 +51,7 @@ def test_reload_lifecycle():
assert tensor.__class__ == meta_tensor.__class__
assert tensor.__dict__ == meta_tensor.__dict__
materialize_layer(layer)
materialize_layer(layer, info)
for name, tensor in get_layer_tensors(layer).items():
materialized_tensor = getattr(layer, name)
assert tensor.dtype == materialized_tensor.dtype
@@ -60,7 +63,10 @@ def test_reload_lifecycle():
def test_model_cleanup(dist_init, default_vllm_config):
layer = QKVParallelLinear(2, 3, 4)
assert layer.weight.weight_loader.__self__ is layer
info = LayerReloadingInfo(restore_metadata=capture_layer_to_meta(layer))
info = LayerReloadingInfo(
restore_metadata=capture_layer_to_meta(layer),
restore_device=torch.device("cpu"),
)
mock_info_dict: WeakKeyDictionary[torch.nn.Module, LayerReloadingInfo] = (
WeakKeyDictionary()
@@ -90,39 +96,46 @@ def test_get_numel_loaded():
assert ret == "value"
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize(
"tp_size", [pytest.param(1), pytest.param(2, marks=[pytest.mark.slow_test])]
)
@pytest.mark.parametrize(
"base_model,mul_model,add_model",
[
(
pytest.param(
"Qwen/Qwen3-0.6B",
"inference-optimization/Qwen3-0.6B-debug-multiply",
"inference-optimization/Qwen3-0.6B-debug-add",
marks=[pytest.mark.slow_test],
),
(
pytest.param(
"inference-optimization/Qwen3-0.6B-FP8_BLOCK",
"inference-optimization/Qwen3-0.6B-debug-multiply-FP8_BLOCK",
"inference-optimization/Qwen3-0.6B-debug-add-FP8_BLOCK",
marks=[pytest.mark.slow_test],
),
(
pytest.param(
"inference-optimization/Qwen3-0.6B-W4A16-G128",
"inference-optimization/Qwen3-0.6B-debug-multiply-W4A16-G128",
"inference-optimization/Qwen3-0.6B-debug-add-W4A16-G128",
marks=[pytest.mark.slow_test],
),
(
pytest.param(
"inference-optimization/DeepSeek-V3-debug-empty",
"inference-optimization/DeepSeek-V3-debug-multiply",
"inference-optimization/DeepSeek-V3-debug-add",
marks=[pytest.mark.slow_test],
),
(
pytest.param(
"inference-optimization/DeepSeek-V3-debug-empty-FP8_DYNAMIC",
"inference-optimization/DeepSeek-V3-debug-multiply-FP8_DYNAMIC",
"inference-optimization/DeepSeek-V3-debug-add-FP8_DYNAMIC",
),
(
pytest.param(
"inference-optimization/DeepSeek-V3-debug-empty-NVFP4A16",
"inference-optimization/DeepSeek-V3-debug-multiply-NVFP4A16",
"inference-optimization/DeepSeek-V3-debug-add-NVFP4A16",
marks=[pytest.mark.slow_test],
),
],
)
@@ -138,6 +151,8 @@ def test_reload_weights(base_model, mul_model, add_model, tp_size, vllm_runner):
tensor_parallel_size=tp_size,
enable_expert_parallel=(tp_size > 1 and "DeepSeek" in base_model),
enable_prefix_caching=False,
max_model_len=16,
max_num_seqs=1,
) as llm:
llm.collective_rpc("reload_weights", kwargs={"weights_path": mul_model})
mul_perp = llm.generate_prompt_perplexity(["3 4 = 12"], mask=["3 4 ="])[0]
@@ -150,34 +165,42 @@ def test_reload_weights(base_model, mul_model, add_model, tp_size, vllm_runner):
assert add_perp < mul_perp
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize(
"tp_size", [pytest.param(1), pytest.param(2, marks=[pytest.mark.slow_test])]
)
@pytest.mark.parametrize(
"base_model,mul_model,add_model,quantization",
[
(
pytest.param(
"Qwen/Qwen3-0.6B",
"inference-optimization/Qwen3-0.6B-debug-multiply",
"inference-optimization/Qwen3-0.6B-debug-add",
"fp8",
),
(
pytest.param(
"inference-optimization/DeepSeek-V3-debug-empty",
"inference-optimization/DeepSeek-V3-debug-multiply",
"inference-optimization/DeepSeek-V3-debug-add",
"fp8",
marks=[pytest.mark.slow_test],
),
(
pytest.param(
"Qwen/Qwen3-0.6B",
"inference-optimization/Qwen3-0.6B-debug-multiply",
"inference-optimization/Qwen3-0.6B-debug-add",
"mxfp8",
marks=[pytest.mark.slow_test],
),
pytest.param(
"inference-optimization/DeepSeek-V3-debug-empty",
"inference-optimization/DeepSeek-V3-debug-multiply",
"inference-optimization/DeepSeek-V3-debug-add",
"mxfp8",
marks=[
pytest.mark.slow_test,
pytest.mark.xfail(reason="mxfp4 & mla is not supported yet"),
],
),
# ( TODO: support mxfp4 & mla
# "inference-optimization/DeepSeek-V3-debug-empty",
# "inference-optimization/DeepSeek-V3-debug-multiply",
# "inference-optimization/DeepSeek-V3-debug-add",
# "mxfp8",
# ),
],
)
def test_online_quantize_reload(
@@ -195,6 +218,8 @@ def test_online_quantize_reload(
tensor_parallel_size=tp_size,
enable_expert_parallel=(tp_size > 1 and "DeepSeek" in base_model),
enable_prefix_caching=False,
max_model_len=16,
max_num_seqs=1,
) as llm:
llm.collective_rpc("reload_weights", kwargs={"weights_path": mul_model})
mul_perp = llm.generate_prompt_perplexity(["3 4 = 12"], mask=["3 4 ="])[0]