2026-01-30 10:50:05 -05:00
|
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
|
import gc
|
|
|
|
|
import inspect
|
|
|
|
|
from weakref import WeakKeyDictionary, ref
|
|
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
from vllm.model_executor.layers.linear import QKVParallelLinear
|
|
|
|
|
from vllm.model_executor.model_loader.reload.meta import (
|
|
|
|
|
capture_layer_to_meta,
|
|
|
|
|
get_numel_loaded,
|
|
|
|
|
materialize_layer,
|
|
|
|
|
materialize_meta_tensor,
|
|
|
|
|
restore_layer_on_meta,
|
|
|
|
|
to_meta_tensor,
|
|
|
|
|
)
|
|
|
|
|
from vllm.model_executor.model_loader.reload.types import LayerReloadingInfo
|
|
|
|
|
from vllm.model_executor.model_loader.reload.utils import get_layer_tensors
|
|
|
|
|
from vllm.platforms import current_platform
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_move_metatensors():
    """Round-trip a tensor through the meta device and compare its metadata.

    The tensor produced by ``to_meta_tensor`` must report the ``meta`` device,
    the tensor produced by ``materialize_meta_tensor`` must be on the same
    device as the starting tensor, and all three must agree on dtype, shape,
    class and instance ``__dict__``.
    """
    original = torch.empty((1, 2, 3))
    on_meta = to_meta_tensor(original)
    rebuilt = materialize_meta_tensor(on_meta)

    # Device placement: meta in the middle, original device at both ends.
    assert on_meta.device.type == "meta"
    assert original.device == rebuilt.device

    # Everything except the device must match across all three tensors.
    for attr in ("dtype", "shape", "__class__", "__dict__"):
        expected = getattr(original, attr)
        assert expected == getattr(on_meta, attr) == getattr(rebuilt, attr)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_reload_lifecycle():
    """Check that tensor metadata is unchanged across a restore/materialize cycle.

    The layer's tensors are snapshotted before any reload step is applied.
    After ``restore_layer_on_meta`` and again after ``materialize_layer``,
    every tensor on the layer is compared against that snapshot for dtype,
    shape, class and instance ``__dict__``.
    """
    layer = torch.nn.Linear(2, 3)

    # Snapshot first: reading the tensors back from the layer after each step
    # and comparing them with `getattr(layer, name)` would compare every
    # tensor with itself and could never fail.
    original_tensors = dict(get_layer_tensors(layer))

    info = LayerReloadingInfo(
        restore_metadata=capture_layer_to_meta(layer),
        restore_device=torch.device("cpu"),
    )

    def assert_matches_original():
        # Compare the layer's current tensors with the pre-reload snapshot.
        for name, original in original_tensors.items():
            current = getattr(layer, name)
            assert original.dtype == current.dtype
            assert original.shape == current.shape
            assert original.__class__ == current.__class__
            assert original.__dict__ == current.__dict__

    restore_layer_on_meta(layer, info)
    assert_matches_original()

    materialize_layer(layer, info)
    assert_matches_original()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_model_cleanup(dist_init, default_vllm_config):
    """Check that reload info does not keep its layer alive.

    A ``QKVParallelLinear`` layer is registered as the key of a
    ``WeakKeyDictionary`` whose value is the layer's ``LayerReloadingInfo``.
    Once the only local name for the layer is deleted and the garbage
    collector has run, the weak reference must be dead and the dictionary
    must be empty.
    """
    layer = QKVParallelLinear(2, 3, 4)
    # The weight loader is a method bound to the layer itself.
    assert layer.weight.weight_loader.__self__ is layer

    reload_info = LayerReloadingInfo(
        restore_metadata=capture_layer_to_meta(layer),
        restore_device=torch.device("cpu"),
    )

    weak_layer = ref(layer)
    registry: WeakKeyDictionary[torch.nn.Module, LayerReloadingInfo] = (
        WeakKeyDictionary()
    )
    registry[layer] = reload_info

    # Drop the last local strong reference, then force a collection pass.
    del layer
    gc.collect()

    assert weak_layer() is None
    assert len(registry) == 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_get_numel_loaded():
    """Check the element count and return value reported by ``get_numel_loaded``.

    The loader below writes two disjoint three-element slices of a
    ten-element parameter, so six elements are expected, together with the
    loader's own return value.
    """

    def complex_weight_loader(param, loaded_weight):
        # Two separate slice assignments: elements 0-2 and 5-7.
        param[:3] = loaded_weight[:3]
        param[5:8] = loaded_weight[5:8]
        return "value"

    meta_param = torch.empty(10, device="meta")
    source_weight = torch.empty(10)

    bound_args = inspect.signature(complex_weight_loader).bind(
        meta_param, source_weight
    )
    outcome = get_numel_loaded(complex_weight_loader, bound_args)
    num_loaded, ret = outcome

    assert num_loaded == 6
    assert ret == "value"
|
|
|
|
|
|
|
|
|
|
|
2026-03-29 16:56:41 -04:00
|
|
|
@pytest.mark.parametrize(
    "tp_size", [pytest.param(1), pytest.param(2, marks=[pytest.mark.slow_test])]
)
@pytest.mark.parametrize(
    "base_model,mul_model,add_model",
    [
        pytest.param(
            "Qwen/Qwen3-0.6B",
            "inference-optimization/Qwen3-0.6B-debug-multiply",
            "inference-optimization/Qwen3-0.6B-debug-add",
            marks=[pytest.mark.slow_test],
        ),
        pytest.param(
            "inference-optimization/Qwen3-0.6B-FP8_BLOCK",
            "inference-optimization/Qwen3-0.6B-debug-multiply-FP8_BLOCK",
            "inference-optimization/Qwen3-0.6B-debug-add-FP8_BLOCK",
            marks=[pytest.mark.slow_test],
        ),
        pytest.param(
            "inference-optimization/Qwen3-0.6B-W4A16-G128",
            "inference-optimization/Qwen3-0.6B-debug-multiply-W4A16-G128",
            "inference-optimization/Qwen3-0.6B-debug-add-W4A16-G128",
            marks=[pytest.mark.slow_test],
        ),
        pytest.param(
            "inference-optimization/DeepSeek-V3-debug-empty",
            "inference-optimization/DeepSeek-V3-debug-multiply",
            "inference-optimization/DeepSeek-V3-debug-add",
            marks=[pytest.mark.slow_test],
        ),
        pytest.param(
            "inference-optimization/DeepSeek-V3-debug-empty-FP8_DYNAMIC",
            "inference-optimization/DeepSeek-V3-debug-multiply-FP8_DYNAMIC",
            "inference-optimization/DeepSeek-V3-debug-add-FP8_DYNAMIC",
        ),
        pytest.param(
            "inference-optimization/DeepSeek-V3-debug-empty-NVFP4A16",
            "inference-optimization/DeepSeek-V3-debug-multiply-NVFP4A16",
            "inference-optimization/DeepSeek-V3-debug-add-NVFP4A16",
            marks=[pytest.mark.slow_test],
        ),
    ],
)
def test_reload_weights(base_model, mul_model, add_model, tp_size, vllm_runner):
    """Reload two checkpoints into a running engine and compare perplexities.

    The engine is started from ``base_model``. After reloading ``mul_model``
    the prompt ``"3 4 = 12"`` must have lower perplexity than ``"3 4 = 7"``;
    after reloading ``add_model`` the ordering must be reversed.

    The test is skipped when fewer than ``tp_size`` devices are available, or
    when the base model name contains ``"FP8"`` and the platform reports no
    FP8 support.
    """
    # `current_platform` is not CUDA-specific, so the skip reason names
    # devices generically.
    if current_platform.device_count() < tp_size:
        pytest.skip(reason="Not enough GPU devices")

    if "FP8" in base_model and not current_platform.supports_fp8():
        pytest.skip(reason="Requires FP8 support")

    with vllm_runner(
        model_name=base_model,
        tensor_parallel_size=tp_size,
        enable_expert_parallel=(tp_size > 1 and "DeepSeek" in base_model),
        enable_prefix_caching=False,
        max_model_len=16,
        max_num_seqs=1,
    ) as llm:

        def reload_and_score(weights_path):
            # Swap in the given weights, then score both candidate answers
            # with the shared "3 4 =" prefix masked.
            llm.collective_rpc("reload_weights", kwargs={"weights_path": weights_path})
            mul = llm.generate_prompt_perplexity(["3 4 = 12"], mask=["3 4 ="])[0]
            add = llm.generate_prompt_perplexity(["3 4 = 7"], mask=["3 4 ="])[0]
            return mul, add

        mul_perp, add_perp = reload_and_score(mul_model)
        assert mul_perp < add_perp

        mul_perp, add_perp = reload_and_score(add_model)
        assert add_perp < mul_perp
|
2026-03-27 16:22:33 -04:00
|
|
|
|
|
|
|
|
|
2026-03-29 16:56:41 -04:00
|
|
|
@pytest.mark.parametrize(
    "tp_size", [pytest.param(1), pytest.param(2, marks=[pytest.mark.slow_test])]
)
@pytest.mark.parametrize(
    "base_model,mul_model,add_model,quantization",
    [
        pytest.param(
            "Qwen/Qwen3-0.6B",
            "inference-optimization/Qwen3-0.6B-debug-multiply",
            "inference-optimization/Qwen3-0.6B-debug-add",
            "fp8",
        ),
        pytest.param(
            "inference-optimization/DeepSeek-V3-debug-empty",
            "inference-optimization/DeepSeek-V3-debug-multiply",
            "inference-optimization/DeepSeek-V3-debug-add",
            "fp8",
            marks=[pytest.mark.slow_test],
        ),
        pytest.param(
            "Qwen/Qwen3-0.6B",
            "inference-optimization/Qwen3-0.6B-debug-multiply",
            "inference-optimization/Qwen3-0.6B-debug-add",
            "mxfp8",
            marks=[pytest.mark.slow_test],
        ),
        pytest.param(
            "inference-optimization/DeepSeek-V3-debug-empty",
            "inference-optimization/DeepSeek-V3-debug-multiply",
            "inference-optimization/DeepSeek-V3-debug-add",
            "mxfp8",
            marks=[
                pytest.mark.slow_test,
                # The reason names the quantization this case actually uses.
                pytest.mark.xfail(reason="mxfp8 & mla is not supported yet"),
            ],
        ),
    ],
)
def test_online_quantize_reload(
    base_model, mul_model, add_model, quantization, tp_size, vllm_runner
):
    """Reload two checkpoints into an engine started with a ``quantization`` setting.

    The engine is started from ``base_model`` with ``quantization`` passed to
    the runner. After reloading ``mul_model`` the prompt ``"3 4 = 12"`` must
    have lower perplexity than ``"3 4 = 7"``; after reloading ``add_model``
    the ordering must be reversed.

    The test is skipped when fewer than ``tp_size`` devices are available, or
    when ``quantization`` is ``"fp8"`` and the platform reports no FP8
    support.
    """
    if current_platform.device_count() < tp_size:
        pytest.skip(reason="Not enough GPU devices")

    if quantization == "fp8" and not current_platform.supports_fp8():
        pytest.skip(reason="Requires FP8 support")

    with vllm_runner(
        model_name=base_model,
        quantization=quantization,
        tensor_parallel_size=tp_size,
        enable_expert_parallel=(tp_size > 1 and "DeepSeek" in base_model),
        enable_prefix_caching=False,
        max_model_len=16,
        max_num_seqs=1,
    ) as llm:

        def reload_and_score(weights_path):
            # Swap in the given weights, then score both candidate answers
            # with the shared "3 4 =" prefix masked.
            llm.collective_rpc("reload_weights", kwargs={"weights_path": weights_path})
            mul = llm.generate_prompt_perplexity(["3 4 = 12"], mask=["3 4 ="])[0]
            add = llm.generate_prompt_perplexity(["3 4 = 7"], mask=["3 4 ="])[0]
            return mul, add

        mul_perp, add_perp = reload_and_score(mul_model)
        assert mul_perp < add_perp

        mul_perp, add_perp = reload_and_score(add_model)
        assert add_perp < mul_perp
|