[QeRL] Layerwise Reloading (#32133)
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
This commit is contained in:
@@ -1,11 +1,11 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import importlib.metadata
|
||||
import importlib.util
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.model_loader import get_model_loader
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
DTYPE = ["bfloat16"]
|
||||
@@ -105,8 +105,8 @@ def test_opt_125m_awq_int4wo_model_loading_with_params(vllm_runner):
|
||||
|
||||
|
||||
@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
|
||||
def test_online_quant_config_dict_json(vllm_runner):
|
||||
"""Testing on the fly quantization, load_weights integration point,
|
||||
def test_online_quant_config_dict_json(vllm_runner, enable_pickle):
|
||||
"""Testing online quantization, load_weights integration point,
|
||||
with config dict serialized to json string
|
||||
"""
|
||||
torch._dynamo.reset()
|
||||
@@ -135,7 +135,18 @@ def test_online_quant_config_dict_json(vllm_runner):
|
||||
) as llm:
|
||||
output = llm.generate_greedy(["The capital of France is"], max_tokens=4)
|
||||
|
||||
assert output
|
||||
load_config = llm.llm.llm_engine.vllm_config.load_config
|
||||
model_config = llm.llm.llm_engine.vllm_config.model_config
|
||||
|
||||
def load_weights(model):
    """Load the checkpoint weights into ``model`` again.

    Builds a model loader from the enclosing ``load_config``, asks it for
    every weight belonging to ``model_config``/``model``, and hands that
    iterator to ``model.load_weights``.
    """
    loader = get_model_loader(load_config)
    model.load_weights(loader.get_all_weights(model_config, model))
|
||||
|
||||
llm.apply_model(load_weights)
|
||||
|
||||
reload_output = llm.generate_greedy(["The capital of France is"], max_tokens=4)
|
||||
assert output[0][0] == reload_output[0][0]
|
||||
|
||||
|
||||
@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
|
||||
|
||||
Reference in New Issue
Block a user