[QeRL] Fix online quantized reloading (#38442)

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Author: Kyle Sayers
Date: 2026-03-29 16:56:41 -04:00
Committed by: GitHub
parent 995dea1354
commit d28d86e8a3
9 changed files with 104 additions and 62 deletions

View File

@@ -812,7 +812,7 @@ steps:
commands:
- apt-get update && apt-get install -y curl libsodium23
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -v -s model_executor
- pytest -v -s model_executor -m '(not slow_test)'
- pytest -v -s entrypoints/openai/completion/test_tensorizer_entrypoint.py
@@ -1242,7 +1242,7 @@ steps:
- vllm/platforms/rocm.py
commands:
- TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)'
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s model_executor/model_loader/test_sharded_state_loader.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s model_executor/model_loader/test_sharded_state_loader.py -m '(not slow_test)'
- pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)'
- pytest models/language -v -s -m 'distributed(num_gpus=2)'
- pytest models/multimodal -v -s -m 'distributed(num_gpus=2)' --ignore models/multimodal/generation/test_whisper.py
@@ -2501,7 +2501,7 @@ steps:
- tests/models/
commands:
- TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)'
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s model_executor/model_loader/test_sharded_state_loader.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s model_executor/model_loader/test_sharded_state_loader.py -m '(not slow_test)'
- pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)'
- pytest models/language -v -s -m 'distributed(num_gpus=2)'
- pytest models/multimodal -v -s -m 'distributed(num_gpus=2)' --ignore models/multimodal/generation/test_whisper.py
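
For context on the `-m '(not slow_test)'` filters added above: they deselect any test or parametrize case carrying the `slow_test` marker that the test diff below applies via `pytest.param(..., marks=[pytest.mark.slow_test])`. A minimal sketch of how such a marker is typically registered and used; the `pytest_configure` hook shown here is standard pytest and not necessarily how this repo declares the marker:

# conftest.py -- register the marker so `-m 'not slow_test'` can deselect it
import pytest

def pytest_configure(config):
    config.addinivalue_line(
        "markers", "slow_test: long-running test, skipped in fast CI lanes"
    )

# test_example.py -- marking a single parametrize case as slow, mirroring this commit
@pytest.mark.parametrize(
    "tp_size",
    [pytest.param(1), pytest.param(2, marks=[pytest.mark.slow_test])],
)
def test_something(tp_size):
    assert tp_size in (1, 2)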

View File

@@ -13,5 +13,5 @@ steps:
commands:
- apt-get update && apt-get install -y curl libsodium23
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
- pytest -v -s model_executor
- pytest -v -s model_executor -m '(not slow_test)'
- pytest -v -s entrypoints/openai/completion/test_tensorizer_entrypoint.py

View File

@@ -14,7 +14,7 @@ steps:
- tests/models/
commands:
- TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)'
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s model_executor/model_loader/test_sharded_state_loader.py
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s model_executor/model_loader/test_sharded_state_loader.py -m '(not slow_test)'
# Avoid importing model tests that cause CUDA reinitialization error
- pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)'
- pytest models/language -v -s -m 'distributed(num_gpus=2)'

View File

@@ -38,7 +38,10 @@ def test_move_metatensors():
def test_reload_lifecycle():
layer = torch.nn.Linear(2, 3)
info = LayerReloadingInfo(restore_metadata=capture_layer_to_meta(layer))
info = LayerReloadingInfo(
restore_metadata=capture_layer_to_meta(layer),
restore_device=torch.device("cpu"),
)
restore_layer_on_meta(layer, info)
for name, tensor in get_layer_tensors(layer).items():
@@ -48,7 +51,7 @@ def test_reload_lifecycle():
assert tensor.__class__ == meta_tensor.__class__
assert tensor.__dict__ == meta_tensor.__dict__
materialize_layer(layer)
materialize_layer(layer, info)
for name, tensor in get_layer_tensors(layer).items():
materialized_tensor = getattr(layer, name)
assert tensor.dtype == materialized_tensor.dtype
@@ -60,7 +63,10 @@ def test_reload_lifecycle():
def test_model_cleanup(dist_init, default_vllm_config):
layer = QKVParallelLinear(2, 3, 4)
assert layer.weight.weight_loader.__self__ is layer
info = LayerReloadingInfo(restore_metadata=capture_layer_to_meta(layer))
info = LayerReloadingInfo(
restore_metadata=capture_layer_to_meta(layer),
restore_device=torch.device("cpu"),
)
mock_info_dict: WeakKeyDictionary[torch.nn.Module, LayerReloadingInfo] = (
WeakKeyDictionary()
@@ -90,39 +96,46 @@ def test_get_numel_loaded():
assert ret == "value"
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize(
"tp_size", [pytest.param(1), pytest.param(2, marks=[pytest.mark.slow_test])]
)
@pytest.mark.parametrize(
"base_model,mul_model,add_model",
[
(
pytest.param(
"Qwen/Qwen3-0.6B",
"inference-optimization/Qwen3-0.6B-debug-multiply",
"inference-optimization/Qwen3-0.6B-debug-add",
marks=[pytest.mark.slow_test],
),
(
pytest.param(
"inference-optimization/Qwen3-0.6B-FP8_BLOCK",
"inference-optimization/Qwen3-0.6B-debug-multiply-FP8_BLOCK",
"inference-optimization/Qwen3-0.6B-debug-add-FP8_BLOCK",
marks=[pytest.mark.slow_test],
),
(
pytest.param(
"inference-optimization/Qwen3-0.6B-W4A16-G128",
"inference-optimization/Qwen3-0.6B-debug-multiply-W4A16-G128",
"inference-optimization/Qwen3-0.6B-debug-add-W4A16-G128",
marks=[pytest.mark.slow_test],
),
(
pytest.param(
"inference-optimization/DeepSeek-V3-debug-empty",
"inference-optimization/DeepSeek-V3-debug-multiply",
"inference-optimization/DeepSeek-V3-debug-add",
marks=[pytest.mark.slow_test],
),
(
pytest.param(
"inference-optimization/DeepSeek-V3-debug-empty-FP8_DYNAMIC",
"inference-optimization/DeepSeek-V3-debug-multiply-FP8_DYNAMIC",
"inference-optimization/DeepSeek-V3-debug-add-FP8_DYNAMIC",
),
(
pytest.param(
"inference-optimization/DeepSeek-V3-debug-empty-NVFP4A16",
"inference-optimization/DeepSeek-V3-debug-multiply-NVFP4A16",
"inference-optimization/DeepSeek-V3-debug-add-NVFP4A16",
marks=[pytest.mark.slow_test],
),
],
)
@@ -138,6 +151,8 @@ def test_reload_weights(base_model, mul_model, add_model, tp_size, vllm_runner):
tensor_parallel_size=tp_size,
enable_expert_parallel=(tp_size > 1 and "DeepSeek" in base_model),
enable_prefix_caching=False,
max_model_len=16,
max_num_seqs=1,
) as llm:
llm.collective_rpc("reload_weights", kwargs={"weights_path": mul_model})
mul_perp = llm.generate_prompt_perplexity(["3 4 = 12"], mask=["3 4 ="])[0]
@@ -150,34 +165,42 @@ def test_reload_weights(base_model, mul_model, add_model, tp_size, vllm_runner):
assert add_perp < mul_perp
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize(
"tp_size", [pytest.param(1), pytest.param(2, marks=[pytest.mark.slow_test])]
)
@pytest.mark.parametrize(
"base_model,mul_model,add_model,quantization",
[
(
pytest.param(
"Qwen/Qwen3-0.6B",
"inference-optimization/Qwen3-0.6B-debug-multiply",
"inference-optimization/Qwen3-0.6B-debug-add",
"fp8",
),
(
pytest.param(
"inference-optimization/DeepSeek-V3-debug-empty",
"inference-optimization/DeepSeek-V3-debug-multiply",
"inference-optimization/DeepSeek-V3-debug-add",
"fp8",
marks=[pytest.mark.slow_test],
),
(
pytest.param(
"Qwen/Qwen3-0.6B",
"inference-optimization/Qwen3-0.6B-debug-multiply",
"inference-optimization/Qwen3-0.6B-debug-add",
"mxfp8",
marks=[pytest.mark.slow_test],
),
pytest.param(
"inference-optimization/DeepSeek-V3-debug-empty",
"inference-optimization/DeepSeek-V3-debug-multiply",
"inference-optimization/DeepSeek-V3-debug-add",
"mxfp8",
marks=[
pytest.mark.slow_test,
pytest.mark.xfail(reason="mxfp4 & mla is not supported yet"),
],
),
# ( TODO: support mxfp4 & mla
# "inference-optimization/DeepSeek-V3-debug-empty",
# "inference-optimization/DeepSeek-V3-debug-multiply",
# "inference-optimization/DeepSeek-V3-debug-add",
# "mxfp8",
# ),
],
)
def test_online_quantize_reload(
@@ -195,6 +218,8 @@ def test_online_quantize_reload(
tensor_parallel_size=tp_size,
enable_expert_parallel=(tp_size > 1 and "DeepSeek" in base_model),
enable_prefix_caching=False,
max_model_len=16,
max_num_seqs=1,
) as llm:
llm.collective_rpc("reload_weights", kwargs={"weights_path": mul_model})
mul_perp = llm.generate_prompt_perplexity(["3 4 = 12"], mask=["3 4 ="])[0]
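
The lifecycle that `test_reload_lifecycle` exercises after this change, condensed into a standalone sketch. The helper names are taken from the hunks above; the import path is a placeholder because the test's import block is not part of this diff:

import torch
# Placeholder module path -- only the function names below come from the diff.
from layerwise_reload_helpers import (  # hypothetical import path
    LayerReloadingInfo,
    capture_layer_to_meta,
    restore_layer_on_meta,
    materialize_layer,
)

layer = torch.nn.Linear(2, 3)

# 1. Record model-format metadata plus the device to materialize on. Passing
#    restore_device explicitly is the new part of this commit.
info = LayerReloadingInfo(
    restore_metadata=capture_layer_to_meta(layer),
    restore_device=torch.device("cpu"),
)

# 2. Swap the layer's parameters/buffers for meta tensors ahead of a reload.
restore_layer_on_meta(layer, info)

# 3. Materialize real storage again; materialize_layer now takes `info` so the
#    allocation happens on info.restore_device instead of an ambient default.
materialize_layer(layer, info)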

View File

@@ -1006,14 +1006,17 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
initialize_online_processing(layer)
def process_weights_after_loading(self, layer: Module) -> None:
# TODO(@ksayers): inplace fp8 quant kernel, initialize scales with ones
if getattr(layer, "_already_called_process_weights_after_loading", False):
return
fp8_dtype = current_platform.fp8_dtype()
w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)
w13_scale = torch.ones(layer.num_experts, dtype=torch.float32)
w2_scale = torch.ones(layer.num_experts, dtype=torch.float32)
w13_scale = torch.ones(
layer.num_experts, device=w13.device, dtype=torch.float32
)
w2_scale = torch.ones(layer.num_experts, device=w2.device, dtype=torch.float32)
layer.w13_input_scale = None
layer.w2_input_scale = None
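
The fix above pins the freshly created weight scales to the same device as the weights; without the explicit `device=`, `torch.ones` allocates on the current default device, which need not be where `w13_weight`/`w2_weight` live. A self-contained illustration (shapes and values are arbitrary; `Fp8OnlineMoEMethod` itself is not involved):

import torch

num_experts = 4
# Stand-in for layer.w13_weight already living on an accelerator (falls back to CPU here).
device = "cuda" if torch.cuda.is_available() else "cpu"
w13 = torch.empty(num_experts, 8, 16, device=device)

scale_wrong = torch.ones(num_experts, dtype=torch.float32)                     # default device
scale_right = torch.ones(num_experts, device=w13.device, dtype=torch.float32)  # follows the weights

# Mixing devices in later kernels and copies is what the explicit `device=` avoids.
assert scale_right.device == w13.device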

View File

@@ -49,7 +49,10 @@ def get_layerwise_info(layer: torch.nn.Module) -> LayerReloadingInfo:
information existed, a new entry is constructed
"""
if layer not in LAYERWISE_INFO:
LAYERWISE_INFO[layer] = LayerReloadingInfo()
LAYERWISE_INFO[layer] = LayerReloadingInfo(
restore_metadata=({}, {}),
restore_device=torch.get_default_device(),
)
return LAYERWISE_INFO[layer]
@@ -64,6 +67,7 @@ def record_metadata_for_reloading(model: torch.nn.Module):
for layer in model.modules():
info = get_layerwise_info(layer)
info.restore_metadata = capture_layer_to_meta(layer)
info.restore_device = torch.get_default_device()
@torch.no_grad()
@@ -99,10 +103,18 @@ def initialize_layerwise_reload(model: torch.nn.Module):
# Restore layer parameters/buffers onto meta device
restore_layer_on_meta(layer, info)
# Wrap weight loaders to buffer loading
initialize_online_processing(layer)
def initialize_online_processing(layer: torch.nn.Module):
"""
Wrap a layer's weight loaders with online processing loaders.
Called by either `initialize_layerwise_reload` or an online quantization scheme;
prevents double wrapping when online quantization and reloading are combined.
:param layer: layer whose parameter weight loaders will be wrapped
"""
info = get_layerwise_info(layer)
# Track loading progress to determine when to process/copy
@@ -211,7 +223,7 @@ def finalize_layerwise_processing(model: torch.nn.Module, model_config: ModelCon
elif info.load_numel <= 0:
# first load but received no weights. This happens on dummy load
if info.kernel_tensors is None:
materialize_layer(layer)
materialize_layer(layer, info)
# reloading: place kernel tensors back as a fallback
else:
@@ -244,7 +256,7 @@ def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo):
4. Copies processed values back to original tensor storage
"""
# Materialize layer tensors onto device
materialize_layer(layer)
materialize_layer(layer, info)
# Reset online quantization flag so process_weights_after_loading
# will run again during reload
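
The new `initialize_online_processing` docstring above notes that it can be reached from two paths and must not wrap the weight loaders twice. A generic sketch of such a double-wrap guard, purely illustrative and not the repo's implementation (the sentinel attribute name is made up):

_WRAPPED_FLAG = "_online_process_wrapped"  # illustrative sentinel, not from the diff

def wrap_weight_loader_once(param, make_wrapper):
    """Wrap param.weight_loader exactly once, even if reached from two code paths."""
    loader = getattr(param, "weight_loader", None)
    if loader is None or getattr(loader, _WRAPPED_FLAG, False):
        return  # nothing to wrap, or already wrapped by the other code path
    wrapped = make_wrapper(loader)
    setattr(wrapped, _WRAPPED_FLAG, True)
    param.weight_loader = wrapped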

View File

@@ -94,14 +94,15 @@ def restore_layer_on_meta(layer: torch.nn.Module, info: LayerReloadingInfo):
layer.register_buffer(name, buffer)
def materialize_layer(layer: torch.nn.Module) -> None:
def materialize_layer(layer: torch.nn.Module, info: LayerReloadingInfo):
"""Materialize all meta tensors in a layer to actual tensors."""
if layer.__class__.__name__ in SKIP_MODULES:
return
for name, tensor in get_layer_tensors(layer).items():
if name not in SKIP_TENSORS:
setattr(layer, name, materialize_meta_tensor(tensor))
with info.restore_device:
for name, tensor in get_layer_tensors(layer).items():
if name not in SKIP_TENSORS:
setattr(layer, name, materialize_meta_tensor(tensor))
class CopyCounter(TorchDispatchMode):
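
`materialize_layer` now allocates under `with info.restore_device:`. Since PyTorch 2.0 a `torch.device` object is usable as a context manager that sets the default device for factory calls inside the block, which is what lets the recorded `restore_device` take effect here. A standalone illustration:

import torch

meta_param = torch.nn.Parameter(torch.empty(4, 4, device="meta"))

restore_device = torch.device("cpu")  # in the diff this comes from info.restore_device
with restore_device:
    # Factory functions allocate on the ambient device set by the context manager.
    materialized = torch.empty(meta_param.shape, dtype=meta_param.dtype)

assert materialized.device == restore_device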

View File

@@ -13,21 +13,26 @@ LayerTensors = tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]
@dataclass
class LayerReloadingInfo:
# model format (meta), populated by `record_metadata_for_reloading`
restore_metadata: LayerTensors = field(default_factory=lambda: ({}, {}))
# model format metadata, recorded by `record_metadata_for_reloading`
restore_metadata: LayerTensors
# kernel format (device), used to copy into when reloading only
kernel_tensors: LayerTensors | None = None
# device to materialize layers with, recorded by `record_metadata_for_reloading`
restore_device: torch.device
# track how many restored elements are ready for loading
# track how many elements are ready for loading, used by `online_process_loader`
load_numel: int = 0
load_numel_total: int | None = None
# stores arguments and tensors ready for loading
# used by `online_process_loader` to buffer args and tensors until ready to load
loaded_weights: list[tuple[str, BoundArguments]] = field(default_factory=list)
# kernel formatted tensors, copied into by `_layerwise_process` when reloading
kernel_tensors: LayerTensors | None = None
def reset(self):
self.__init__(restore_metadata=self.restore_metadata) # type: ignore[misc]
self.__init__( # type: ignore[misc]
restore_metadata=self.restore_metadata, restore_device=self.restore_device
)
def can_load(self) -> bool:
return self.load_numel_total is not None
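
A trimmed-down sketch of the reset contract after this change: the two recorded fields survive `reset()`, while the transient loading state returns to its defaults. Types are simplified so the example stands alone; this is not the real `LayerReloadingInfo`:

import torch
from dataclasses import dataclass, field

@dataclass
class ReloadInfoSketch:                      # simplified stand-in for LayerReloadingInfo
    restore_metadata: tuple                  # recorded by record_metadata_for_reloading
    restore_device: torch.device             # recorded alongside the metadata
    load_numel: int = 0                      # in-flight loading progress
    loaded_weights: list = field(default_factory=list)

    def reset(self):
        # Re-init keeps the recorded fields and clears the transient loading state.
        self.__init__(restore_metadata=self.restore_metadata,
                      restore_device=self.restore_device)

info = ReloadInfoSketch(({}, {}), torch.device("cpu"))
info.load_numel = 42
info.reset()
assert info.load_numel == 0 and info.restore_device.type == "cpu"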

View File

@@ -4943,28 +4943,24 @@ class GPUModelRunner(
# begin loading weights
logger.info_once("Reloading weights inplace...", scope="local")
load_device = (
self.vllm_config.load_config.device or self.vllm_config.device_config.device
)
with torch.device(load_device):
if is_checkpoint_format:
# load weights from checkpoint/ original model format
initialize_layerwise_reload(model)
loaded_weights = model.load_weights(weights_iterator)
finalize_layerwise_reload(model, self.model_config)
if is_checkpoint_format:
# load weights from checkpoint/ original model format
initialize_layerwise_reload(model)
loaded_weights = model.load_weights(weights_iterator)
finalize_layerwise_reload(model, self.model_config)
else:
# load weights from kernel format
logger.warning_once(
"Reloading with `is_checkpoint_format=True` requires that "
"weights be in kernel format and already sharded",
scope="local",
)
loaded_weights = set()
for name, loaded_weight in weights_iterator:
param = model.get_parameter(name) # TODO: buffers?
param.copy_(loaded_weight)
loaded_weights.add(name)
else:
# load weights from kernel format
logger.warning_once(
"Reloading with `is_checkpoint_format=True` requires that "
"weights be in kernel format and already sharded",
scope="local",
)
loaded_weights = set()
for name, loaded_weight in weights_iterator:
param = model.get_parameter(name) # TODO: buffers?
param.copy_(loaded_weight)
loaded_weights.add(name)
# logging and validation
counter_after_reloading = time.perf_counter()
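
The model-runner change above drops the blanket `with torch.device(load_device):` wrapper: the checkpoint-format path now relies on the per-layer `restore_device` recorded earlier, while the kernel-format path copies straight into the existing parameters. A stripped-down sketch of that kernel-format branch, assuming the incoming weights are already sharded as the warning states (`model` and `weights_iterator` are placeholders):

import torch

@torch.no_grad()
def reload_kernel_format(model: torch.nn.Module, weights_iterator):
    """Copy pre-sharded, kernel-format tensors into the live parameters in place."""
    loaded = set()
    for name, new_weight in weights_iterator:
        param = model.get_parameter(name)  # buffers are not handled, mirroring the TODO above
        param.copy_(new_weight)            # in-place copy keeps the existing storage and device
        loaded.add(name)
    return loaded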