[QeRL] Compose online quantization with quantized reloading (#38032)
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
This commit is contained in:
@@ -148,3 +148,60 @@ def test_reload_weights(base_model, mul_model, add_model, tp_size, vllm_runner):
|
||||
mul_perp = llm.generate_prompt_perplexity(["3 4 = 12"], mask=["3 4 ="])[0]
|
||||
add_perp = llm.generate_prompt_perplexity(["3 4 = 7"], mask=["3 4 ="])[0]
|
||||
assert add_perp < mul_perp
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize(
    "base_model,mul_model,add_model,quantization",
    [
        (
            "Qwen/Qwen3-0.6B",
            "inference-optimization/Qwen3-0.6B-debug-multiply",
            "inference-optimization/Qwen3-0.6B-debug-add",
            "fp8",
        ),
        (
            "inference-optimization/DeepSeek-V3-debug-empty",
            "inference-optimization/DeepSeek-V3-debug-multiply",
            "inference-optimization/DeepSeek-V3-debug-add",
            "fp8",
        ),
        (
            "Qwen/Qwen3-0.6B",
            "inference-optimization/Qwen3-0.6B-debug-multiply",
            "inference-optimization/Qwen3-0.6B-debug-add",
            "mxfp8",
        ),
        # ( TODO: support mxfp4 & mla
        #     "inference-optimization/DeepSeek-V3-debug-empty",
        #     "inference-optimization/DeepSeek-V3-debug-multiply",
        #     "inference-optimization/DeepSeek-V3-debug-add",
        #     "mxfp8",
        # ),
    ],
)
def test_online_quantize_reload(
    base_model, mul_model, add_model, quantization, tp_size, vllm_runner
):
    """Verify weight hot-reloading composes with online quantization.

    Loads ``base_model`` under the given online ``quantization`` scheme, then
    swaps in the "multiply" and "add" debug checkpoints via the
    ``reload_weights`` collective RPC, checking after each swap that the
    model's perplexity preference flips to the matching arithmetic answer.
    """
    # Guard clauses: skip when the hardware cannot run this parametrization.
    if cuda_device_count_stateless() < tp_size:
        pytest.skip(reason="Not enough CUDA devices")

    if quantization == "fp8" and not current_platform.supports_fp8():
        pytest.skip(reason="Requires FP8 support")

    # Shared prompts: same prefix mask, one completion per arithmetic op.
    mul_prompt = ["3 4 = 12"]
    add_prompt = ["3 4 = 7"]
    prefix_mask = ["3 4 ="]

    with vllm_runner(
        model_name=base_model,
        quantization=quantization,
        tensor_parallel_size=tp_size,
        # Expert parallelism only applies to the multi-GPU DeepSeek MoE case.
        enable_expert_parallel=(tp_size > 1 and "DeepSeek" in base_model),
        enable_prefix_caching=False,
    ) as llm:
        # After loading the "multiply" checkpoint, the multiplication answer
        # should be the more likely one (lower perplexity).
        llm.collective_rpc("reload_weights", kwargs={"weights_path": mul_model})
        perp_of_mul = llm.generate_prompt_perplexity(mul_prompt, mask=prefix_mask)[0]
        perp_of_add = llm.generate_prompt_perplexity(add_prompt, mask=prefix_mask)[0]
        assert perp_of_mul < perp_of_add

        # Reloading the "add" checkpoint should flip the preference.
        llm.collective_rpc("reload_weights", kwargs={"weights_path": add_model})
        perp_of_mul = llm.generate_prompt_perplexity(mul_prompt, mask=prefix_mask)[0]
        perp_of_add = llm.generate_prompt_perplexity(add_prompt, mask=prefix_mask)[0]
        assert perp_of_add < perp_of_mul
|
||||
|
||||
Reference in New Issue
Block a user