[Bugfix] Validate lora adapters to avoid crashing server (#11727)

Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Author: Joe Runde
Date: 2025-01-10 00:56:36 -07:00
Committed by: GitHub
Parent: cf5f000d21
Commit: ac2f3f7fee
15 changed files with 460 additions and 172 deletions
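For context, a minimal client-side sketch (not part of the commit) of the behavior this change targets: loading an adapter with a bad path. It assumes a locally running server with the runtime LoRA endpoints enabled (VLLM_ALLOW_RUNTIME_LORA_UPDATING); the host, port, and adapter names are made up. Previously a request like this could crash the engine; after this change it returns a structured error while the server keeps running.

import requests

# Hypothetical request: the adapter path does not exist on the server.
resp = requests.post(
    "http://localhost:8000/v1/load_lora_adapter",  # assumed local dev server
    json={"lora_name": "my-adapter", "lora_path": "/does/not/exist"},
)
# Expected after this change: HTTP 404 with err_type "NotFoundError"
# ("No adapter found ..."), or HTTP 400 "BadRequestError" for other
# adapter/config problems, instead of the engine going down.
print(resp.status_code, resp.json())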

View File

@@ -662,7 +662,7 @@ def build_app(args: Namespace) -> FastAPI:
return app
def init_app_state(
async def init_app_state(
engine_client: EngineClient,
model_config: ModelConfig,
state: State,
@@ -690,12 +690,13 @@ def init_app_state(
logger.info("Using supplied chat template:\n%s", resolved_chat_template)
state.openai_serving_models = OpenAIServingModels(
engine_client=engine_client,
model_config=model_config,
base_model_paths=base_model_paths,
lora_modules=args.lora_modules,
prompt_adapters=args.prompt_adapters,
)
# TODO: The chat template is now broken for lora adapters :(
await state.openai_serving_models.init_static_loras()
state.openai_serving_chat = OpenAIServingChat(
engine_client,
model_config,
@@ -794,7 +795,7 @@ async def run_server(args, **uvicorn_kwargs) -> None:
app = build_app(args)
model_config = await engine_client.get_model_config()
init_app_state(engine_client, model_config, app.state, args)
await init_app_state(engine_client, model_config, app.state, args)
shutdown_task = await serve_http(
app,

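The practical effect of making init_app_state a coroutine is that static adapters passed via --lora-modules are loaded, and therefore validated, while the server is still starting up. A rough sketch of the failure path, reusing the names from the hunks above (simplified, not the actual vLLM code):

async def start(engine_client, app, args):
    model_config = await engine_client.get_model_config()
    # init_app_state now awaits init_static_loras(); if the engine rejects a
    # static adapter, the ValueError propagates here and startup aborts
    # before serve_http() is reached, instead of crashing on a later request.
    await init_app_state(engine_client, model_config, app.state, args)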
View File

@@ -215,6 +215,7 @@ async def main(args):
# Create the openai serving objects.
openai_serving_models = OpenAIServingModels(
engine_client=engine,
model_config=model_config,
base_model_paths=base_model_paths,
lora_modules=None,

View File

@@ -5,15 +5,19 @@ from http import HTTPStatus
from typing import List, Optional, Union
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.protocol import (ErrorResponse,
LoadLoraAdapterRequest,
ModelCard, ModelList,
ModelPermission,
UnloadLoraAdapterRequest)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.utils import AtomicCounter
logger = init_logger(__name__)
@dataclass
class BaseModelPath:
@@ -45,6 +49,7 @@ class OpenAIServingModels:
def __init__(
self,
engine_client: EngineClient,
model_config: ModelConfig,
base_model_paths: List[BaseModelPath],
*,
@@ -55,20 +60,11 @@ class OpenAIServingModels:
self.base_model_paths = base_model_paths
self.max_model_len = model_config.max_model_len
self.engine_client = engine_client
self.static_lora_modules = lora_modules
self.lora_requests: List[LoRARequest] = []
self.lora_id_counter = AtomicCounter(0)
self.lora_requests = []
if lora_modules is not None:
self.lora_requests = [
LoRARequest(lora_name=lora.name,
lora_int_id=i,
lora_path=lora.path,
base_model_name=lora.base_model_name
if lora.base_model_name
and self.is_base_model(lora.base_model_name) else
self.base_model_paths[0].name)
for i, lora in enumerate(lora_modules, start=1)
]
self.prompt_adapter_requests = []
if prompt_adapters is not None:
@@ -84,6 +80,19 @@ class OpenAIServingModels:
prompt_adapter_local_path=prompt_adapter.local_path,
prompt_adapter_num_virtual_tokens=num_virtual_tokens))
async def init_static_loras(self):
"""Loads all static LoRA modules.
Raises if any fail to load"""
if self.static_lora_modules is None:
return
for lora in self.static_lora_modules:
load_request = LoadLoraAdapterRequest(lora_path=lora.path,
lora_name=lora.name)
load_result = await self.load_lora_adapter(
request=load_request, base_model_name=lora.base_model_name)
if isinstance(load_result, ErrorResponse):
raise ValueError(load_result.message)
def is_base_model(self, model_name):
return any(model.name == model_name for model in self.base_model_paths)
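A minimal sketch of how the new constructor arguments and init_static_loras fit together, mirroring the api_server/run_batch hunks above; the surrounding variables (engine_client, model_config, base_model_paths, args) are assumed to exist:

async def setup_serving_models(engine_client, model_config, base_model_paths, args):
    serving_models = OpenAIServingModels(
        engine_client=engine_client,        # new required argument
        model_config=model_config,          # new required argument
        base_model_paths=base_model_paths,
        lora_modules=args.lora_modules,      # static --lora-modules entries
        prompt_adapters=args.prompt_adapters,
    )
    # Static adapters are no longer turned into LoRARequests inside __init__;
    # they are loaded (and validated) explicitly, which may raise ValueError.
    await serving_models.init_static_loras()
    return serving_models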
@@ -129,17 +138,47 @@ class OpenAIServingModels:
async def load_lora_adapter(
self,
request: LoadLoraAdapterRequest) -> Union[ErrorResponse, str]:
request: LoadLoraAdapterRequest,
base_model_name: Optional[str] = None
) -> Union[ErrorResponse, str]:
error_check_ret = await self._check_load_lora_adapter_request(request)
if error_check_ret is not None:
return error_check_ret
lora_name, lora_path = request.lora_name, request.lora_path
unique_id = self.lora_id_counter.inc(1)
self.lora_requests.append(
LoRARequest(lora_name=lora_name,
lora_int_id=unique_id,
lora_path=lora_path))
lora_request = LoRARequest(lora_name=lora_name,
lora_int_id=unique_id,
lora_path=lora_path)
if base_model_name is not None and self.is_base_model(base_model_name):
lora_request.base_model_name = base_model_name
# Validate that the adapter can be loaded into the engine
# This will also pre-load it for incoming requests
try:
await self.engine_client.add_lora(lora_request)
except ValueError as e:
# Adapter not found or lora configuration errors
if "No adapter found" in str(e):
return create_error_response(message=str(e),
err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND)
else:
return create_error_response(
message=str(e),
err_type="BadRequestError",
status_code=HTTPStatus.BAD_REQUEST)
except BaseException as e:
# Some other unexpected problem loading the adapter, e.g. malformed
# input files.
# More detailed error messages for the user would be nicer here
return create_error_response(message=str(e),
err_type="BadRequestError",
status_code=HTTPStatus.BAD_REQUEST)
self.lora_requests.append(lora_request)
logger.info("Loaded new LoRA adapter: name '%s', path '%s'", lora_name,
lora_path)
return f"Success: LoRA adapter '{lora_name}' added successfully."
async def unload_lora_adapter(
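For a caller using OpenAIServingModels directly, the validated loader now returns an ErrorResponse instead of letting engine exceptions escape. A hedged sketch (the adapter name and path are made up):

from vllm.entrypoints.openai.protocol import (ErrorResponse,
                                              LoadLoraAdapterRequest)

async def try_load(serving_models):
    result = await serving_models.load_lora_adapter(
        LoadLoraAdapterRequest(lora_name="sql-lora", lora_path="/tmp/missing"))
    if isinstance(result, ErrorResponse):
        # "No adapter found ..." maps to NotFoundError (404); other engine
        # rejections (bad config, malformed files) map to BadRequestError (400).
        print("load failed:", result.message)
    else:
        print(result)  # "Success: LoRA adapter 'sql-lora' added successfully."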
@@ -155,6 +194,7 @@ class OpenAIServingModels:
lora_request for lora_request in self.lora_requests
if lora_request.lora_name != lora_name
]
logger.info("Removed LoRA adapter: name '%s'", lora_name)
return f"Success: LoRA adapter '{lora_name}' removed successfully."
async def _check_load_lora_adapter_request(
@@ -195,8 +235,8 @@ class OpenAIServingModels:
return create_error_response(
message=
f"The lora adapter '{request.lora_name}' cannot be found.",
err_type="InvalidUserInput",
status_code=HTTPStatus.BAD_REQUEST)
err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND)
return None
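The last hunk changes the error type for unloading an adapter that was never loaded: the check now reports NotFoundError with HTTP 404 instead of InvalidUserInput with 400. A hedged client-side example, under the same assumptions as the load example near the top:

import requests

resp = requests.post(
    "http://localhost:8000/v1/unload_lora_adapter",  # assumed local dev server
    json={"lora_name": "never-loaded"},
)
# Expected after this change: 404 with err_type "NotFoundError" rather than
# the previous 400 "InvalidUserInput".
print(resp.status_code, resp.json())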