[Bugfix] Validate lora adapters to avoid crashing server (#11727)
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
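Before this change, LoRA adapters named via `--lora-modules` or the dynamic `/v1/load_lora_adapter` endpoint were only recorded in the server's bookkeeping; a bad name or path surfaced later as an engine crash on the first request that used the adapter. This commit routes every adapter through `engine_client.add_lora()` at load time, so bad static adapters abort startup with a clear error and bad dynamic loads come back as HTTP error responses instead of taking the server down.

A minimal sketch of the new dynamic-load behavior, assuming a server started with runtime LoRA updates enabled (`VLLM_ALLOW_RUNTIME_LORA_UPDATING=True`); host, port, and the adapter path are placeholders:

```python
import requests

# Ask the server to load an adapter from a path that does not exist.
resp = requests.post(
    "http://localhost:8000/v1/load_lora_adapter",
    json={"lora_name": "my-adapter", "lora_path": "/nonexistent/path"},
)

# Previously the load "succeeded" and the first completion that selected
# my-adapter crashed the engine. Now the failure is reported up front:
# 404 (NotFoundError) when no adapter is found, 400 (BadRequestError)
# for other configuration problems.
print(resp.status_code, resp.json())
```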
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -662,7 +662,7 @@ def build_app(args: Namespace) -> FastAPI:
     return app
 
 
-def init_app_state(
+async def init_app_state(
     engine_client: EngineClient,
     model_config: ModelConfig,
     state: State,
@@ -690,12 +690,13 @@ def init_app_state(
         logger.info("Using supplied chat template:\n%s", resolved_chat_template)
 
     state.openai_serving_models = OpenAIServingModels(
+        engine_client=engine_client,
         model_config=model_config,
         base_model_paths=base_model_paths,
         lora_modules=args.lora_modules,
         prompt_adapters=args.prompt_adapters,
     )
-    # TODO: The chat template is now broken for lora adapters :(
+    await state.openai_serving_models.init_static_loras()
     state.openai_serving_chat = OpenAIServingChat(
         engine_client,
         model_config,
@@ -794,7 +795,7 @@ async def run_server(args, **uvicorn_kwargs) -> None:
         app = build_app(args)
 
         model_config = await engine_client.get_model_config()
-        init_app_state(engine_client, model_config, app.state, args)
+        await init_app_state(engine_client, model_config, app.state, args)
 
         shutdown_task = await serve_http(
             app,
--- a/vllm/entrypoints/openai/run_batch.py
+++ b/vllm/entrypoints/openai/run_batch.py
@@ -215,6 +215,7 @@ async def main(args):
 
     # Create the openai serving objects.
     openai_serving_models = OpenAIServingModels(
+        engine_client=engine,
        model_config=model_config,
        base_model_paths=base_model_paths,
        lora_modules=None,
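run_batch passes `lora_modules=None`, so there are no static adapters to validate there; `init_static_loras()` (added in a later hunk) returns immediately in that case. A quick illustration, reusing the names from the hunk above inside an async context:

```python
# No static adapters: the new startup hook is a no-op for batch mode.
openai_serving_models = OpenAIServingModels(
    engine_client=engine,
    model_config=model_config,
    base_model_paths=base_model_paths,
    lora_modules=None,
)
await openai_serving_models.init_static_loras()  # returns immediately
```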
--- a/vllm/entrypoints/openai/serving_models.py
+++ b/vllm/entrypoints/openai/serving_models.py
@@ -5,15 +5,19 @@ from http import HTTPStatus
 from typing import List, Optional, Union
 
 from vllm.config import ModelConfig
+from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.openai.protocol import (ErrorResponse,
                                               LoadLoraAdapterRequest,
                                               ModelCard, ModelList,
                                               ModelPermission,
                                               UnloadLoraAdapterRequest)
+from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.utils import AtomicCounter
 
+logger = init_logger(__name__)
+
 
 @dataclass
 class BaseModelPath:
@@ -45,6 +49,7 @@ class OpenAIServingModels:
 
     def __init__(
         self,
+        engine_client: EngineClient,
         model_config: ModelConfig,
         base_model_paths: List[BaseModelPath],
         *,
@@ -55,20 +60,11 @@ class OpenAIServingModels:
 
         self.base_model_paths = base_model_paths
         self.max_model_len = model_config.max_model_len
+        self.engine_client = engine_client
 
+        self.static_lora_modules = lora_modules
+        self.lora_requests: List[LoRARequest] = []
         self.lora_id_counter = AtomicCounter(0)
-        self.lora_requests = []
-        if lora_modules is not None:
-            self.lora_requests = [
-                LoRARequest(lora_name=lora.name,
-                            lora_int_id=i,
-                            lora_path=lora.path,
-                            base_model_name=lora.base_model_name
-                            if lora.base_model_name
-                            and self.is_base_model(lora.base_model_name) else
-                            self.base_model_paths[0].name)
-                for i, lora in enumerate(lora_modules, start=1)
-            ]
 
         self.prompt_adapter_requests = []
         if prompt_adapters is not None:
@@ -84,6 +80,19 @@ class OpenAIServingModels:
                     prompt_adapter_local_path=prompt_adapter.local_path,
                     prompt_adapter_num_virtual_tokens=num_virtual_tokens))
 
+    async def init_static_loras(self):
+        """Loads all static LoRA modules.
+        Raises if any fail to load"""
+        if self.static_lora_modules is None:
+            return
+        for lora in self.static_lora_modules:
+            load_request = LoadLoraAdapterRequest(lora_path=lora.path,
+                                                  lora_name=lora.name)
+            load_result = await self.load_lora_adapter(
+                request=load_request, base_model_name=lora.base_model_name)
+            if isinstance(load_result, ErrorResponse):
+                raise ValueError(load_result.message)
+
     def is_base_model(self, model_name):
         return any(model.name == model_name for model in self.base_model_paths)
 
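`init_static_loras()` funnels each `--lora-modules` entry through the same validated `load_lora_adapter()` path and raises on the first `ErrorResponse`, so a typo'd adapter path now stops the server at startup rather than poisoning it for later requests. A rough sketch of the failure mode inside an async startup routine; it assumes `LoRAModulePath` sits alongside `BaseModelPath` in serving_models.py, and the bogus path is illustrative:

```python
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
                                                    LoRAModulePath,
                                                    OpenAIServingModels)

models = OpenAIServingModels(
    engine_client=engine_client,  # any EngineClient implementation
    model_config=model_config,
    base_model_paths=[BaseModelPath(name="base", model_path="my/base-model")],
    lora_modules=[LoRAModulePath(name="bad", path="/no/such/adapter")],
)
# Raises ValueError with the engine's message instead of deferring the
# crash to the first request that selects "bad".
await models.init_static_loras()
```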
@@ -129,17 +138,47 @@ class OpenAIServingModels:
     async def load_lora_adapter(
             self,
-            request: LoadLoraAdapterRequest) -> Union[ErrorResponse, str]:
+            request: LoadLoraAdapterRequest,
+            base_model_name: Optional[str] = None
+    ) -> Union[ErrorResponse, str]:
         error_check_ret = await self._check_load_lora_adapter_request(request)
         if error_check_ret is not None:
             return error_check_ret
 
         lora_name, lora_path = request.lora_name, request.lora_path
         unique_id = self.lora_id_counter.inc(1)
-        self.lora_requests.append(
-            LoRARequest(lora_name=lora_name,
-                        lora_int_id=unique_id,
-                        lora_path=lora_path))
+        lora_request = LoRARequest(lora_name=lora_name,
+                                   lora_int_id=unique_id,
+                                   lora_path=lora_path)
+        if base_model_name is not None and self.is_base_model(base_model_name):
+            lora_request.base_model_name = base_model_name
+
+        # Validate that the adapter can be loaded into the engine
+        # This will also pre-load it for incoming requests
+        try:
+            await self.engine_client.add_lora(lora_request)
+        except ValueError as e:
+            # Adapter not found or lora configuration errors
+            if "No adapter found" in str(e):
+                return create_error_response(message=str(e),
+                                             err_type="NotFoundError",
+                                             status_code=HTTPStatus.NOT_FOUND)
+            else:
+                return create_error_response(
+                    message=str(e),
+                    err_type="BadRequestError",
+                    status_code=HTTPStatus.BAD_REQUEST)
+        except BaseException as e:
+            # Some other unexpected problem loading the adapter, e.g. malformed
+            # input files.
+            # More detailed error messages for the user would be nicer here
+            return create_error_response(message=str(e),
+                                         err_type="BadRequestError",
+                                         status_code=HTTPStatus.BAD_REQUEST)
+
+        self.lora_requests.append(lora_request)
         logger.info("Loaded new LoRA adapter: name '%s', path '%s'", lora_name,
                     lora_path)
         return f"Success: LoRA adapter '{lora_name}' added successfully."
 
     async def unload_lora_adapter(
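Because `load_lora_adapter()` now returns either an `ErrorResponse` or a success string, callers branch on the type. Roughly how an HTTP handler consumes the union (a sketch, not the exact api_server.py wiring):

```python
from fastapi import Request, Response
from fastapi.responses import JSONResponse

from vllm.entrypoints.openai.protocol import (ErrorResponse,
                                              LoadLoraAdapterRequest)

async def handle_load_lora_adapter(request: LoadLoraAdapterRequest,
                                   raw_request: Request):
    models = raw_request.app.state.openai_serving_models
    response = await models.load_lora_adapter(request)
    if isinstance(response, ErrorResponse):
        # create_error_response() stamps the HTTP status into .code,
        # e.g. 404 for "No adapter found", 400 otherwise.
        return JSONResponse(content=response.model_dump(),
                            status_code=response.code)
    return Response(status_code=200, content=response)
```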
@@ -155,6 +194,7 @@ class OpenAIServingModels:
             lora_request for lora_request in self.lora_requests
             if lora_request.lora_name != lora_name
         ]
+        logger.info("Removed LoRA adapter: name '%s'", lora_name)
         return f"Success: LoRA adapter '{lora_name}' removed successfully."
 
     async def _check_load_lora_adapter_request(
@@ -195,8 +235,8 @@ class OpenAIServingModels:
             return create_error_response(
                 message=
                 f"The lora adapter '{request.lora_name}' cannot be found.",
-                err_type="InvalidUserInput",
-                status_code=HTTPStatus.BAD_REQUEST)
+                err_type="NotFoundError",
+                status_code=HTTPStatus.NOT_FOUND)
 
         return None
 
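With the `NotFoundError` change above, the unload validation now reports 404 for an adapter that was never loaded, instead of a generic 400. For example (endpoint name per the existing dynamic-LoRA API; host and adapter name are placeholders):

```python
import requests

resp = requests.post(
    "http://localhost:8000/v1/unload_lora_adapter",
    json={"lora_name": "never-loaded"},
)
assert resp.status_code == 404  # was 400 "InvalidUserInput" before this fix
```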