[Bugfix][Refactor] Unify model management in frontend (#11660)
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
@@ -4,7 +4,7 @@ import pytest
 from vllm.entrypoints.openai.cli_args import (make_arg_parser,
                                               validate_parsed_serve_args)
-from vllm.entrypoints.openai.serving_engine import LoRAModulePath
+from vllm.entrypoints.openai.serving_models import LoRAModulePath
 from vllm.utils import FlexibleArgumentParser

 from ...utils import VLLM_PATH

@@ -55,7 +55,10 @@ def server_with_lora_modules_json(zephyr_lora_files):
         "64",
     ]

-    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+    # Enable the /v1/load_lora_adapter endpoint
+    envs = {"VLLM_ALLOW_RUNTIME_LORA_UPDATING": "True"}
+
+    with RemoteOpenAIServer(MODEL_NAME, args, env_dict=envs) as remote_server:
         yield remote_server


@@ -67,8 +70,8 @@ async def client_for_lora_lineage(server_with_lora_modules_json):


 @pytest.mark.asyncio
-async def test_check_lora_lineage(client_for_lora_lineage: openai.AsyncOpenAI,
-                                  zephyr_lora_files):
+async def test_static_lora_lineage(client_for_lora_lineage: openai.AsyncOpenAI,
+                                   zephyr_lora_files):
     models = await client_for_lora_lineage.models.list()
     models = models.data
     served_model = models[0]
@@ -81,3 +84,26 @@ async def test_check_lora_lineage(client_for_lora_lineage: openai.AsyncOpenAI,
     assert all(lora_model.parent == MODEL_NAME for lora_model in lora_models)
     assert lora_models[0].id == "zephyr-lora"
     assert lora_models[1].id == "zephyr-lora2"
+
+
+@pytest.mark.asyncio
+async def test_dynamic_lora_lineage(
+        client_for_lora_lineage: openai.AsyncOpenAI, zephyr_lora_files):
+
+    response = await client_for_lora_lineage.post("load_lora_adapter",
+                                                  cast_to=str,
+                                                  body={
+                                                      "lora_name":
+                                                      "zephyr-lora-3",
+                                                      "lora_path":
+                                                      zephyr_lora_files
+                                                  })
+    # Ensure adapter loads before querying /models
+    assert "success" in response
+
+    models = await client_for_lora_lineage.models.list()
+    models = models.data
+    dynamic_lora_model = models[-1]
+    assert dynamic_lora_model.root == zephyr_lora_files
+    assert dynamic_lora_model.parent == MODEL_NAME
+    assert dynamic_lora_model.id == "zephyr-lora-3"
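The new `test_dynamic_lora_lineage` above loads an adapter at runtime through the `/v1/load_lora_adapter` endpoint and then checks its lineage via `/v1/models`. For reference, here is a minimal standalone sketch of the same flow over plain HTTP; the base URL, port, and adapter path are placeholders, and it assumes the server was started with LoRA enabled and `VLLM_ALLOW_RUNTIME_LORA_UPDATING=True`, as in the fixture above.

```python
# Hedged sketch of the runtime-LoRA flow the test above exercises.
# Assumptions: a vLLM OpenAI-compatible server on localhost:8000 started with
# --enable-lora and VLLM_ALLOW_RUNTIME_LORA_UPDATING=True; ADAPTER_PATH is a
# placeholder for a real local adapter directory.
import requests

BASE_URL = "http://localhost:8000"
ADAPTER_PATH = "/path/to/zephyr-lora"

# Register the adapter under a new served model name.
resp = requests.post(f"{BASE_URL}/v1/load_lora_adapter",
                     json={
                         "lora_name": "zephyr-lora-3",
                         "lora_path": ADAPTER_PATH,
                     })
resp.raise_for_status()
assert "success" in resp.text.lower()

# The adapter now appears in the model list, with `root` set to its path and
# `parent` set to the base model, mirroring the assertions in the test above.
models = requests.get(f"{BASE_URL}/v1/models").json()["data"]
print([(m["id"], m.get("parent")) for m in models])
```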
@@ -8,7 +8,8 @@ from vllm.config import MultiModalConfig
 from vllm.engine.multiprocessing.client import MQLLMEngineClient
 from vllm.entrypoints.openai.protocol import ChatCompletionRequest
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
-from vllm.entrypoints.openai.serving_engine import BaseModelPath
+from vllm.entrypoints.openai.serving_models import (BaseModelPath,
+                                                     OpenAIServingModels)
 from vllm.transformers_utils.tokenizer import get_tokenizer

 MODEL_NAME = "openai-community/gpt2"
@@ -50,14 +51,13 @@ async def _async_serving_chat_init():
     engine = MockEngine()
     model_config = await engine.get_model_config()

+    models = OpenAIServingModels(model_config, BASE_MODEL_PATHS)
     serving_completion = OpenAIServingChat(engine,
                                            model_config,
-                                           BASE_MODEL_PATHS,
+                                           models,
                                            response_role="assistant",
                                            chat_template=CHAT_TEMPLATE,
                                            chat_template_content_format="auto",
-                                           lora_modules=None,
-                                           prompt_adapters=None,
                                            request_logger=None)
     return serving_completion

@@ -72,14 +72,14 @@ def test_serving_chat_should_set_correct_max_tokens():
     mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
     mock_engine.errored = False

+    models = OpenAIServingModels(base_model_paths=BASE_MODEL_PATHS,
+                                 model_config=MockModelConfig())
     serving_chat = OpenAIServingChat(mock_engine,
                                      MockModelConfig(),
-                                     BASE_MODEL_PATHS,
+                                     models,
                                      response_role="assistant",
                                      chat_template=CHAT_TEMPLATE,
                                      chat_template_content_format="auto",
-                                     lora_modules=None,
-                                     prompt_adapters=None,
                                      request_logger=None)
     req = ChatCompletionRequest(
         model=MODEL_NAME,
@@ -115,14 +115,14 @@ def test_serving_chat_could_load_correct_generation_config():
     mock_engine.errored = False

     # Initialize the serving chat
+    models = OpenAIServingModels(base_model_paths=BASE_MODEL_PATHS,
+                                 model_config=mock_model_config)
     serving_chat = OpenAIServingChat(mock_engine,
                                      mock_model_config,
-                                     BASE_MODEL_PATHS,
+                                     models,
                                      response_role="assistant",
                                      chat_template=CHAT_TEMPLATE,
                                      chat_template_content_format="auto",
-                                     lora_modules=None,
-                                     prompt_adapters=None,
                                      request_logger=None)
     req = ChatCompletionRequest(
         model=MODEL_NAME,

@@ -4,11 +4,11 @@ from unittest.mock import MagicMock
 import pytest

 from vllm.config import ModelConfig
-from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.openai.protocol import (ErrorResponse,
                                               LoadLoraAdapterRequest,
                                               UnloadLoraAdapterRequest)
-from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
+from vllm.entrypoints.openai.serving_models import (BaseModelPath,
+                                                     OpenAIServingModels)
 from vllm.lora.request import LoRARequest

 MODEL_NAME = "meta-llama/Llama-2-7b"
@@ -19,47 +19,45 @@ LORA_UNLOADING_SUCCESS_MESSAGE = (
     "Success: LoRA adapter '{lora_name}' removed successfully.")


-async def _async_serving_engine_init():
-    mock_engine_client = MagicMock(spec=EngineClient)
+async def _async_serving_models_init() -> OpenAIServingModels:
     mock_model_config = MagicMock(spec=ModelConfig)
     # Set the max_model_len attribute to avoid missing attribute
     mock_model_config.max_model_len = 2048

-    serving_engine = OpenAIServing(mock_engine_client,
-                                   mock_model_config,
-                                   BASE_MODEL_PATHS,
-                                   lora_modules=None,
-                                   prompt_adapters=None,
-                                   request_logger=None)
-    return serving_engine
+    serving_models = OpenAIServingModels(base_model_paths=BASE_MODEL_PATHS,
+                                         model_config=mock_model_config,
+                                         lora_modules=None,
+                                         prompt_adapters=None)
+
+    return serving_models


 @pytest.mark.asyncio
 async def test_serving_model_name():
-    serving_engine = await _async_serving_engine_init()
-    assert serving_engine._get_model_name(None) == MODEL_NAME
+    serving_models = await _async_serving_models_init()
+    assert serving_models.model_name(None) == MODEL_NAME
     request = LoRARequest(lora_name="adapter",
                           lora_path="/path/to/adapter2",
                           lora_int_id=1)
-    assert serving_engine._get_model_name(request) == request.lora_name
+    assert serving_models.model_name(request) == request.lora_name


 @pytest.mark.asyncio
 async def test_load_lora_adapter_success():
-    serving_engine = await _async_serving_engine_init()
+    serving_models = await _async_serving_models_init()
     request = LoadLoraAdapterRequest(lora_name="adapter",
                                      lora_path="/path/to/adapter2")
-    response = await serving_engine.load_lora_adapter(request)
+    response = await serving_models.load_lora_adapter(request)
     assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name='adapter')
-    assert len(serving_engine.lora_requests) == 1
-    assert serving_engine.lora_requests[0].lora_name == "adapter"
+    assert len(serving_models.lora_requests) == 1
+    assert serving_models.lora_requests[0].lora_name == "adapter"


 @pytest.mark.asyncio
 async def test_load_lora_adapter_missing_fields():
-    serving_engine = await _async_serving_engine_init()
+    serving_models = await _async_serving_models_init()
     request = LoadLoraAdapterRequest(lora_name="", lora_path="")
-    response = await serving_engine.load_lora_adapter(request)
+    response = await serving_models.load_lora_adapter(request)
     assert isinstance(response, ErrorResponse)
     assert response.type == "InvalidUserInput"
     assert response.code == HTTPStatus.BAD_REQUEST
@@ -67,43 +65,43 @@ async def test_load_lora_adapter_missing_fields():

 @pytest.mark.asyncio
 async def test_load_lora_adapter_duplicate():
-    serving_engine = await _async_serving_engine_init()
+    serving_models = await _async_serving_models_init()
     request = LoadLoraAdapterRequest(lora_name="adapter1",
                                      lora_path="/path/to/adapter1")
-    response = await serving_engine.load_lora_adapter(request)
+    response = await serving_models.load_lora_adapter(request)
     assert response == LORA_LOADING_SUCCESS_MESSAGE.format(
         lora_name='adapter1')
-    assert len(serving_engine.lora_requests) == 1
+    assert len(serving_models.lora_requests) == 1

     request = LoadLoraAdapterRequest(lora_name="adapter1",
                                      lora_path="/path/to/adapter1")
-    response = await serving_engine.load_lora_adapter(request)
+    response = await serving_models.load_lora_adapter(request)
     assert isinstance(response, ErrorResponse)
     assert response.type == "InvalidUserInput"
     assert response.code == HTTPStatus.BAD_REQUEST
-    assert len(serving_engine.lora_requests) == 1
+    assert len(serving_models.lora_requests) == 1


 @pytest.mark.asyncio
 async def test_unload_lora_adapter_success():
-    serving_engine = await _async_serving_engine_init()
+    serving_models = await _async_serving_models_init()
     request = LoadLoraAdapterRequest(lora_name="adapter1",
                                      lora_path="/path/to/adapter1")
-    response = await serving_engine.load_lora_adapter(request)
-    assert len(serving_engine.lora_requests) == 1
+    response = await serving_models.load_lora_adapter(request)
+    assert len(serving_models.lora_requests) == 1

     request = UnloadLoraAdapterRequest(lora_name="adapter1")
-    response = await serving_engine.unload_lora_adapter(request)
+    response = await serving_models.unload_lora_adapter(request)
     assert response == LORA_UNLOADING_SUCCESS_MESSAGE.format(
         lora_name='adapter1')
-    assert len(serving_engine.lora_requests) == 0
+    assert len(serving_models.lora_requests) == 0


 @pytest.mark.asyncio
 async def test_unload_lora_adapter_missing_fields():
-    serving_engine = await _async_serving_engine_init()
+    serving_models = await _async_serving_models_init()
     request = UnloadLoraAdapterRequest(lora_name="", lora_int_id=None)
-    response = await serving_engine.unload_lora_adapter(request)
+    response = await serving_models.unload_lora_adapter(request)
     assert isinstance(response, ErrorResponse)
     assert response.type == "InvalidUserInput"
     assert response.code == HTTPStatus.BAD_REQUEST
@@ -111,9 +109,9 @@ async def test_unload_lora_adapter_missing_fields():

 @pytest.mark.asyncio
 async def test_unload_lora_adapter_not_found():
-    serving_engine = await _async_serving_engine_init()
+    serving_models = await _async_serving_models_init()
     request = UnloadLoraAdapterRequest(lora_name="nonexistent_adapter")
-    response = await serving_engine.unload_lora_adapter(request)
+    response = await serving_models.unload_lora_adapter(request)
     assert isinstance(response, ErrorResponse)
     assert response.type == "InvalidUserInput"
     assert response.code == HTTPStatus.BAD_REQUEST
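Taken together, these hunks sketch the new surface of `OpenAIServingModels`: it owns the base-model paths and the mutable `lora_requests` list, and exposes `model_name()`, `load_lora_adapter()`, and `unload_lora_adapter()`. Below is a rough, self-contained sketch of driving it directly, modeled on the tests above; the `BaseModelPath(name=..., model_path=...)` field names and the mocked `ModelConfig` are assumptions carried over from how the tests construct these objects.

```python
# Hedged sketch of the unified OpenAIServingModels API, mirroring the tests
# above. The BaseModelPath field names and the mocked ModelConfig are
# assumptions; the method and attribute names come from the diff itself.
import asyncio
from unittest.mock import MagicMock

from vllm.config import ModelConfig
from vllm.entrypoints.openai.protocol import (LoadLoraAdapterRequest,
                                              UnloadLoraAdapterRequest)
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
                                                    OpenAIServingModels)


async def main() -> None:
    mock_model_config = MagicMock(spec=ModelConfig)
    mock_model_config.max_model_len = 2048  # avoid a missing-attribute error

    models = OpenAIServingModels(
        base_model_paths=[BaseModelPath(name="base", model_path="base")],
        model_config=mock_model_config,
        lora_modules=None,
        prompt_adapters=None)

    # With no LoRARequest, the served name falls back to the base model.
    assert models.model_name(None) == "base"

    # Loading registers the adapter in models.lora_requests ...
    await models.load_lora_adapter(
        LoadLoraAdapterRequest(lora_name="adapter",
                               lora_path="/path/to/adapter"))
    assert len(models.lora_requests) == 1

    # ... and unloading removes it again.
    await models.unload_lora_adapter(
        UnloadLoraAdapterRequest(lora_name="adapter"))
    assert len(models.lora_requests) == 0


asyncio.run(main())
```

The serving handlers (chat, completions, and so on) are then constructed with this single shared object instead of separate `base_model_paths`, `lora_modules`, and `prompt_adapters` arguments, which is the unification the commit title refers to.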