[QeRL] Compose online quantization with quantized reloading (#38032)

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
This commit is contained in:
Kyle Sayers
2026-03-27 16:22:33 -04:00
committed by GitHub
parent 7ba425e916
commit 648edcf729
10 changed files with 184 additions and 260 deletions

View File

@@ -148,3 +148,60 @@ def test_reload_weights(base_model, mul_model, add_model, tp_size, vllm_runner):
mul_perp = llm.generate_prompt_perplexity(["3 4 = 12"], mask=["3 4 ="])[0]
add_perp = llm.generate_prompt_perplexity(["3 4 = 7"], mask=["3 4 ="])[0]
assert add_perp < mul_perp
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize(
    "base_model,mul_model,add_model,quantization",
    [
        (
            "Qwen/Qwen3-0.6B",
            "inference-optimization/Qwen3-0.6B-debug-multiply",
            "inference-optimization/Qwen3-0.6B-debug-add",
            "fp8",
        ),
        (
            "inference-optimization/DeepSeek-V3-debug-empty",
            "inference-optimization/DeepSeek-V3-debug-multiply",
            "inference-optimization/DeepSeek-V3-debug-add",
            "fp8",
        ),
        (
            "Qwen/Qwen3-0.6B",
            "inference-optimization/Qwen3-0.6B-debug-multiply",
            "inference-optimization/Qwen3-0.6B-debug-add",
            "mxfp8",
        ),
        # ( TODO: support mxfp4 & mla
        #     "inference-optimization/DeepSeek-V3-debug-empty",
        #     "inference-optimization/DeepSeek-V3-debug-multiply",
        #     "inference-optimization/DeepSeek-V3-debug-add",
        #     "mxfp8",
        # ),
    ],
)
def test_online_quantize_reload(
    base_model, mul_model, add_model, quantization, tp_size, vllm_runner
):
    """Compose online quantization with hot-reloaded checkpoints.

    Boots ``base_model`` under online quantization, swaps in the "multiply"
    and then the "add" debug checkpoints via the ``reload_weights`` RPC, and
    checks that prompt perplexity flips to favor whichever checkpoint is
    currently loaded.
    """
    # Environment guards: enough devices for TP, and platform FP8 support.
    if cuda_device_count_stateless() < tp_size:
        pytest.skip(reason="Not enough CUDA devices")
    # NOTE(review): only "fp8" is gated on supports_fp8(); confirm whether
    # "mxfp8" needs the same (or its own) capability check.
    if quantization == "fp8" and not current_platform.supports_fp8():
        pytest.skip(reason="Requires FP8 support")

    def _swap_and_score(llm, weights_path):
        # Reload the given checkpoint, then score both candidate completions
        # of "3 4 =". Returns (multiply_perplexity, add_perplexity); lower
        # perplexity means the loaded weights prefer that completion.
        llm.collective_rpc("reload_weights", kwargs={"weights_path": weights_path})
        mul = llm.generate_prompt_perplexity(["3 4 = 12"], mask=["3 4 ="])[0]
        add = llm.generate_prompt_perplexity(["3 4 = 7"], mask=["3 4 ="])[0]
        return mul, add

    with vllm_runner(
        model_name=base_model,
        quantization=quantization,
        tensor_parallel_size=tp_size,
        enable_expert_parallel=(tp_size > 1 and "DeepSeek" in base_model),
        enable_prefix_caching=False,
    ) as llm:
        # With the multiply checkpoint loaded, "3 4 = 12" must be the more
        # likely (lower-perplexity) completion.
        mul_perp, add_perp = _swap_and_score(llm, mul_model)
        assert mul_perp < add_perp
        # After reloading the add checkpoint, "3 4 = 7" must win instead.
        mul_perp, add_perp = _swap_and_score(llm, add_model)
        assert add_perp < mul_perp