[QeRL] Fix online quantized reloading (#38442)
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
This commit is contained in:
@@ -812,7 +812,7 @@ steps:
|
||||
commands:
|
||||
- apt-get update && apt-get install -y curl libsodium23
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
- pytest -v -s model_executor
|
||||
- pytest -v -s model_executor -m '(not slow_test)'
|
||||
- pytest -v -s entrypoints/openai/completion/test_tensorizer_entrypoint.py
|
||||
|
||||
|
||||
@@ -1242,7 +1242,7 @@ steps:
|
||||
- vllm/platforms/rocm.py
|
||||
commands:
|
||||
- TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)'
|
||||
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s model_executor/model_loader/test_sharded_state_loader.py
|
||||
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s model_executor/model_loader/test_sharded_state_loader.py -m '(not slow_test)'
|
||||
- pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)'
|
||||
- pytest models/language -v -s -m 'distributed(num_gpus=2)'
|
||||
- pytest models/multimodal -v -s -m 'distributed(num_gpus=2)' --ignore models/multimodal/generation/test_whisper.py
|
||||
@@ -2501,7 +2501,7 @@ steps:
|
||||
- tests/models/
|
||||
commands:
|
||||
- TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)'
|
||||
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s model_executor/model_loader/test_sharded_state_loader.py
|
||||
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s model_executor/model_loader/test_sharded_state_loader.py -m '(not slow_test)'
|
||||
- pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)'
|
||||
- pytest models/language -v -s -m 'distributed(num_gpus=2)'
|
||||
- pytest models/multimodal -v -s -m 'distributed(num_gpus=2)' --ignore models/multimodal/generation/test_whisper.py
|
||||
|
||||
@@ -13,5 +13,5 @@ steps:
|
||||
commands:
|
||||
- apt-get update && apt-get install -y curl libsodium23
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
- pytest -v -s model_executor
|
||||
- pytest -v -s model_executor -m '(not slow_test)'
|
||||
- pytest -v -s entrypoints/openai/completion/test_tensorizer_entrypoint.py
|
||||
|
||||
@@ -14,7 +14,7 @@ steps:
|
||||
- tests/models/
|
||||
commands:
|
||||
- TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)'
|
||||
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s model_executor/model_loader/test_sharded_state_loader.py
|
||||
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s model_executor/model_loader/test_sharded_state_loader.py -m '(not slow_test)'
|
||||
# Avoid importing model tests that cause CUDA reinitialization error
|
||||
- pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)'
|
||||
- pytest models/language -v -s -m 'distributed(num_gpus=2)'
|
||||
|
||||
@@ -38,7 +38,10 @@ def test_move_metatensors():
|
||||
|
||||
def test_reload_lifecycle():
|
||||
layer = torch.nn.Linear(2, 3)
|
||||
info = LayerReloadingInfo(restore_metadata=capture_layer_to_meta(layer))
|
||||
info = LayerReloadingInfo(
|
||||
restore_metadata=capture_layer_to_meta(layer),
|
||||
restore_device=torch.device("cpu"),
|
||||
)
|
||||
|
||||
restore_layer_on_meta(layer, info)
|
||||
for name, tensor in get_layer_tensors(layer).items():
|
||||
@@ -48,7 +51,7 @@ def test_reload_lifecycle():
|
||||
assert tensor.__class__ == meta_tensor.__class__
|
||||
assert tensor.__dict__ == meta_tensor.__dict__
|
||||
|
||||
materialize_layer(layer)
|
||||
materialize_layer(layer, info)
|
||||
for name, tensor in get_layer_tensors(layer).items():
|
||||
materialized_tensor = getattr(layer, name)
|
||||
assert tensor.dtype == materialized_tensor.dtype
|
||||
@@ -60,7 +63,10 @@ def test_reload_lifecycle():
|
||||
def test_model_cleanup(dist_init, default_vllm_config):
|
||||
layer = QKVParallelLinear(2, 3, 4)
|
||||
assert layer.weight.weight_loader.__self__ is layer
|
||||
info = LayerReloadingInfo(restore_metadata=capture_layer_to_meta(layer))
|
||||
info = LayerReloadingInfo(
|
||||
restore_metadata=capture_layer_to_meta(layer),
|
||||
restore_device=torch.device("cpu"),
|
||||
)
|
||||
|
||||
mock_info_dict: WeakKeyDictionary[torch.nn.Module, LayerReloadingInfo] = (
|
||||
WeakKeyDictionary()
|
||||
@@ -90,39 +96,46 @@ def test_get_numel_loaded():
|
||||
assert ret == "value"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tp_size", [2])
|
||||
@pytest.mark.parametrize(
|
||||
"tp_size", [pytest.param(1), pytest.param(2, marks=[pytest.mark.slow_test])]
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"base_model,mul_model,add_model",
|
||||
[
|
||||
(
|
||||
pytest.param(
|
||||
"Qwen/Qwen3-0.6B",
|
||||
"inference-optimization/Qwen3-0.6B-debug-multiply",
|
||||
"inference-optimization/Qwen3-0.6B-debug-add",
|
||||
marks=[pytest.mark.slow_test],
|
||||
),
|
||||
(
|
||||
pytest.param(
|
||||
"inference-optimization/Qwen3-0.6B-FP8_BLOCK",
|
||||
"inference-optimization/Qwen3-0.6B-debug-multiply-FP8_BLOCK",
|
||||
"inference-optimization/Qwen3-0.6B-debug-add-FP8_BLOCK",
|
||||
marks=[pytest.mark.slow_test],
|
||||
),
|
||||
(
|
||||
pytest.param(
|
||||
"inference-optimization/Qwen3-0.6B-W4A16-G128",
|
||||
"inference-optimization/Qwen3-0.6B-debug-multiply-W4A16-G128",
|
||||
"inference-optimization/Qwen3-0.6B-debug-add-W4A16-G128",
|
||||
marks=[pytest.mark.slow_test],
|
||||
),
|
||||
(
|
||||
pytest.param(
|
||||
"inference-optimization/DeepSeek-V3-debug-empty",
|
||||
"inference-optimization/DeepSeek-V3-debug-multiply",
|
||||
"inference-optimization/DeepSeek-V3-debug-add",
|
||||
marks=[pytest.mark.slow_test],
|
||||
),
|
||||
(
|
||||
pytest.param(
|
||||
"inference-optimization/DeepSeek-V3-debug-empty-FP8_DYNAMIC",
|
||||
"inference-optimization/DeepSeek-V3-debug-multiply-FP8_DYNAMIC",
|
||||
"inference-optimization/DeepSeek-V3-debug-add-FP8_DYNAMIC",
|
||||
),
|
||||
(
|
||||
pytest.param(
|
||||
"inference-optimization/DeepSeek-V3-debug-empty-NVFP4A16",
|
||||
"inference-optimization/DeepSeek-V3-debug-multiply-NVFP4A16",
|
||||
"inference-optimization/DeepSeek-V3-debug-add-NVFP4A16",
|
||||
marks=[pytest.mark.slow_test],
|
||||
),
|
||||
],
|
||||
)
|
||||
@@ -138,6 +151,8 @@ def test_reload_weights(base_model, mul_model, add_model, tp_size, vllm_runner):
|
||||
tensor_parallel_size=tp_size,
|
||||
enable_expert_parallel=(tp_size > 1 and "DeepSeek" in base_model),
|
||||
enable_prefix_caching=False,
|
||||
max_model_len=16,
|
||||
max_num_seqs=1,
|
||||
) as llm:
|
||||
llm.collective_rpc("reload_weights", kwargs={"weights_path": mul_model})
|
||||
mul_perp = llm.generate_prompt_perplexity(["3 4 = 12"], mask=["3 4 ="])[0]
|
||||
@@ -150,34 +165,42 @@ def test_reload_weights(base_model, mul_model, add_model, tp_size, vllm_runner):
|
||||
assert add_perp < mul_perp
|
||||
|
||||
|
||||
@pytest.mark.parametrize("tp_size", [2])
|
||||
@pytest.mark.parametrize(
|
||||
"tp_size", [pytest.param(1), pytest.param(2, marks=[pytest.mark.slow_test])]
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"base_model,mul_model,add_model,quantization",
|
||||
[
|
||||
(
|
||||
pytest.param(
|
||||
"Qwen/Qwen3-0.6B",
|
||||
"inference-optimization/Qwen3-0.6B-debug-multiply",
|
||||
"inference-optimization/Qwen3-0.6B-debug-add",
|
||||
"fp8",
|
||||
),
|
||||
(
|
||||
pytest.param(
|
||||
"inference-optimization/DeepSeek-V3-debug-empty",
|
||||
"inference-optimization/DeepSeek-V3-debug-multiply",
|
||||
"inference-optimization/DeepSeek-V3-debug-add",
|
||||
"fp8",
|
||||
marks=[pytest.mark.slow_test],
|
||||
),
|
||||
(
|
||||
pytest.param(
|
||||
"Qwen/Qwen3-0.6B",
|
||||
"inference-optimization/Qwen3-0.6B-debug-multiply",
|
||||
"inference-optimization/Qwen3-0.6B-debug-add",
|
||||
"mxfp8",
|
||||
marks=[pytest.mark.slow_test],
|
||||
),
|
||||
pytest.param(
|
||||
"inference-optimization/DeepSeek-V3-debug-empty",
|
||||
"inference-optimization/DeepSeek-V3-debug-multiply",
|
||||
"inference-optimization/DeepSeek-V3-debug-add",
|
||||
"mxfp8",
|
||||
marks=[
|
||||
pytest.mark.slow_test,
|
||||
pytest.mark.xfail(reason="mxfp4 & mla is not supported yet"),
|
||||
],
|
||||
),
|
||||
# ( TODO: support mxfp4 & mla
|
||||
# "inference-optimization/DeepSeek-V3-debug-empty",
|
||||
# "inference-optimization/DeepSeek-V3-debug-multiply",
|
||||
# "inference-optimization/DeepSeek-V3-debug-add",
|
||||
# "mxfp8",
|
||||
# ),
|
||||
],
|
||||
)
|
||||
def test_online_quantize_reload(
|
||||
@@ -195,6 +218,8 @@ def test_online_quantize_reload(
|
||||
tensor_parallel_size=tp_size,
|
||||
enable_expert_parallel=(tp_size > 1 and "DeepSeek" in base_model),
|
||||
enable_prefix_caching=False,
|
||||
max_model_len=16,
|
||||
max_num_seqs=1,
|
||||
) as llm:
|
||||
llm.collective_rpc("reload_weights", kwargs={"weights_path": mul_model})
|
||||
mul_perp = llm.generate_prompt_perplexity(["3 4 = 12"], mask=["3 4 ="])[0]
|
||||
|
||||
@@ -1006,14 +1006,17 @@ class Fp8OnlineMoEMethod(Fp8MoEMethod):
|
||||
initialize_online_processing(layer)
|
||||
|
||||
def process_weights_after_loading(self, layer: Module) -> None:
|
||||
# TODO(@ksayers): inplace fp8 quant kernel, initialize scales with ones
|
||||
if getattr(layer, "_already_called_process_weights_after_loading", False):
|
||||
return
|
||||
|
||||
fp8_dtype = current_platform.fp8_dtype()
|
||||
w13 = torch.empty_like(layer.w13_weight, dtype=fp8_dtype)
|
||||
w2 = torch.empty_like(layer.w2_weight, dtype=fp8_dtype)
|
||||
w13_scale = torch.ones(layer.num_experts, dtype=torch.float32)
|
||||
w2_scale = torch.ones(layer.num_experts, dtype=torch.float32)
|
||||
w13_scale = torch.ones(
|
||||
layer.num_experts, device=w13.device, dtype=torch.float32
|
||||
)
|
||||
w2_scale = torch.ones(layer.num_experts, device=w2.device, dtype=torch.float32)
|
||||
layer.w13_input_scale = None
|
||||
layer.w2_input_scale = None
|
||||
|
||||
|
||||
@@ -49,7 +49,10 @@ def get_layerwise_info(layer: torch.nn.Module) -> LayerReloadingInfo:
|
||||
information existed, a new entry is constructed
|
||||
"""
|
||||
if layer not in LAYERWISE_INFO:
|
||||
LAYERWISE_INFO[layer] = LayerReloadingInfo()
|
||||
LAYERWISE_INFO[layer] = LayerReloadingInfo(
|
||||
restore_metadata=({}, {}),
|
||||
restore_device=torch.get_default_device(),
|
||||
)
|
||||
|
||||
return LAYERWISE_INFO[layer]
|
||||
|
||||
@@ -64,6 +67,7 @@ def record_metadata_for_reloading(model: torch.nn.Module):
|
||||
for layer in model.modules():
|
||||
info = get_layerwise_info(layer)
|
||||
info.restore_metadata = capture_layer_to_meta(layer)
|
||||
info.restore_device = torch.get_default_device()
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
@@ -99,10 +103,18 @@ def initialize_layerwise_reload(model: torch.nn.Module):
|
||||
# Restore layer parameters/buffers onto meta device
|
||||
restore_layer_on_meta(layer, info)
|
||||
|
||||
# Wrap weight loaders to buffer loading
|
||||
initialize_online_processing(layer)
|
||||
|
||||
|
||||
def initialize_online_processing(layer: torch.nn.Module):
|
||||
"""
|
||||
Wrap a layer's weight loaders with online processing loaders.
|
||||
Called by either `initialize_layerwise_reload` or an online quantization scheme,
|
||||
prevents double wrapping in the case of online quantization + reloading
|
||||
|
||||
:param layer: layer whose parameter weight loaders will be wrapped
|
||||
"""
|
||||
info = get_layerwise_info(layer)
|
||||
|
||||
# Track loading progress to determine when to process/copy
|
||||
@@ -211,7 +223,7 @@ def finalize_layerwise_processing(model: torch.nn.Module, model_config: ModelCon
|
||||
elif info.load_numel <= 0:
|
||||
# first load but received no weights. This happens on dummy load
|
||||
if info.kernel_tensors is None:
|
||||
materialize_layer(layer)
|
||||
materialize_layer(layer, info)
|
||||
|
||||
# reloading: place kernel tensors back as a fallback
|
||||
else:
|
||||
@@ -244,7 +256,7 @@ def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo):
|
||||
4. Copies processed values back to original tensor storage
|
||||
"""
|
||||
# Materialize layer tensors onto device
|
||||
materialize_layer(layer)
|
||||
materialize_layer(layer, info)
|
||||
|
||||
# Reset online quantization flag so process_weights_after_loading
|
||||
# will run again during reload
|
||||
|
||||
@@ -94,14 +94,15 @@ def restore_layer_on_meta(layer: torch.nn.Module, info: LayerReloadingInfo):
|
||||
layer.register_buffer(name, buffer)
|
||||
|
||||
|
||||
def materialize_layer(layer: torch.nn.Module) -> None:
|
||||
def materialize_layer(layer: torch.nn.Module, info: LayerReloadingInfo):
|
||||
"""Materialize all meta tensors in a layer to actual tensors."""
|
||||
if layer.__class__.__name__ in SKIP_MODULES:
|
||||
return
|
||||
|
||||
for name, tensor in get_layer_tensors(layer).items():
|
||||
if name not in SKIP_TENSORS:
|
||||
setattr(layer, name, materialize_meta_tensor(tensor))
|
||||
with info.restore_device:
|
||||
for name, tensor in get_layer_tensors(layer).items():
|
||||
if name not in SKIP_TENSORS:
|
||||
setattr(layer, name, materialize_meta_tensor(tensor))
|
||||
|
||||
|
||||
class CopyCounter(TorchDispatchMode):
|
||||
|
||||
@@ -13,21 +13,26 @@ LayerTensors = tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]
|
||||
|
||||
@dataclass
|
||||
class LayerReloadingInfo:
|
||||
# model format (meta), populated by `record_metadata_for_reloading`
|
||||
restore_metadata: LayerTensors = field(default_factory=lambda: ({}, {}))
|
||||
# model format metadata, recorded by `record_metadata_for_reloading`
|
||||
restore_metadata: LayerTensors
|
||||
|
||||
# kernel format (device), used to copy into when reloading only
|
||||
kernel_tensors: LayerTensors | None = None
|
||||
# device to materialize layers with, recorded by `record_metadata_for_reloading`
|
||||
restore_device: torch.device
|
||||
|
||||
# track how many restored elements are ready for loading
|
||||
# track how many elements are ready for loading, used by `online_process_loader`
|
||||
load_numel: int = 0
|
||||
load_numel_total: int | None = None
|
||||
|
||||
# stores arguments and tensors ready for loading
|
||||
# used by `online_process_loader` to buffer args and tensors until ready to load
|
||||
loaded_weights: list[tuple[str, BoundArguments]] = field(default_factory=list)
|
||||
|
||||
# kernel formatted tensors, copied into by `_layerwise_process` when reloading
|
||||
kernel_tensors: LayerTensors | None = None
|
||||
|
||||
def reset(self):
|
||||
self.__init__(restore_metadata=self.restore_metadata) # type: ignore[misc]
|
||||
self.__init__( # type: ignore[misc]
|
||||
restore_metadata=self.restore_metadata, restore_device=self.restore_device
|
||||
)
|
||||
|
||||
def can_load(self) -> bool:
|
||||
return self.load_numel_total is not None
|
||||
|
||||
@@ -4943,28 +4943,24 @@ class GPUModelRunner(
|
||||
|
||||
# begin loading weights
|
||||
logger.info_once("Reloading weights inplace...", scope="local")
|
||||
load_device = (
|
||||
self.vllm_config.load_config.device or self.vllm_config.device_config.device
|
||||
)
|
||||
with torch.device(load_device):
|
||||
if is_checkpoint_format:
|
||||
# load weights from checkpoint/ original model format
|
||||
initialize_layerwise_reload(model)
|
||||
loaded_weights = model.load_weights(weights_iterator)
|
||||
finalize_layerwise_reload(model, self.model_config)
|
||||
if is_checkpoint_format:
|
||||
# load weights from checkpoint/ original model format
|
||||
initialize_layerwise_reload(model)
|
||||
loaded_weights = model.load_weights(weights_iterator)
|
||||
finalize_layerwise_reload(model, self.model_config)
|
||||
|
||||
else:
|
||||
# load weights from kernel format
|
||||
logger.warning_once(
|
||||
"Reloading with `is_checkpoint_format=True` requires that "
|
||||
"weights be in kernel format and already sharded",
|
||||
scope="local",
|
||||
)
|
||||
loaded_weights = set()
|
||||
for name, loaded_weight in weights_iterator:
|
||||
param = model.get_parameter(name) # TODO: buffers?
|
||||
param.copy_(loaded_weight)
|
||||
loaded_weights.add(name)
|
||||
else:
|
||||
# load weights from kernel format
|
||||
logger.warning_once(
|
||||
"Reloading with `is_checkpoint_format=True` requires that "
|
||||
"weights be in kernel format and already sharded",
|
||||
scope="local",
|
||||
)
|
||||
loaded_weights = set()
|
||||
for name, loaded_weight in weights_iterator:
|
||||
param = model.get_parameter(name) # TODO: buffers?
|
||||
param.copy_(loaded_weight)
|
||||
loaded_weights.add(name)
|
||||
|
||||
# logging and validation
|
||||
counter_after_reloading = time.perf_counter()
|
||||
|
||||
Reference in New Issue
Block a user