[Misc] Enable vLLM to Dynamically Load LoRA from a Remote Server (#10546)
Signed-off-by: Angky William <angkywilliam@Angkys-MacBook-Pro.local>
Co-authored-by: Angky William <angkywilliam@Angkys-MacBook-Pro.local>
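For anyone trying the feature out: with the server started with VLLM_ALLOW_RUNTIME_LORA_UPDATING=1 and at least one LoRAResolver registered, a request naming an adapter that is not yet loaded should now trigger resolution on first use instead of an immediate 404. A minimal client-side sketch; the adapter name, server address, and the existence of a resolver that can actually find the adapter are all assumptions, not part of this diff:

    # Assumes: vllm serve <base-model> running on localhost:8000 with
    # VLLM_ALLOW_RUNTIME_LORA_UPDATING=1 and a resolver that can locate
    # the hypothetical adapter "my-remote-lora".
    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    completion = client.completions.create(
        model="my-remote-lora",  # not loaded yet; resolved on first use
        prompt="Hello,",
        max_tokens=8,
    )
    print(completion.choices[0].text)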
@@ -10,6 +10,7 @@ from fastapi import Request
 from pydantic import Field
 from starlette.datastructures import Headers
 
+import vllm.envs as envs
 from vllm.config import ModelConfig
 from vllm.engine.protocol import EngineClient
 # yapf conflicts with isort for this block
@@ -125,18 +126,29 @@ class OpenAIServing:
         self,
         request: AnyRequest,
     ) -> Optional[ErrorResponse]:
+
+        error_response = None
+
         if self._is_model_supported(request.model):
             return None
         if request.model in [
                 lora.lora_name for lora in self.models.lora_requests
         ]:
             return None
+        if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and request.model and (
+                load_result := await self.models.resolve_lora(request.model)):
+            if isinstance(load_result, LoRARequest):
+                return None
+            if isinstance(load_result, ErrorResponse) and \
+                    load_result.code == HTTPStatus.BAD_REQUEST.value:
+                error_response = load_result
         if request.model in [
                 prompt_adapter.prompt_adapter_name
                 for prompt_adapter in self.models.prompt_adapter_requests
         ]:
             return None
-        return self.create_error_response(
+        return error_response or self.create_error_response(
             message=f"The model `{request.model}` does not exist.",
             err_type="NotFoundError",
             status_code=HTTPStatus.NOT_FOUND)
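Note on the control flow above: a 400 from resolve_lora is not returned immediately; it is parked in error_response so the prompt-adapter name check still gets a chance to match, and only at the end does it take precedence over the generic 404. A hypothetical probe of the two error outcomes, assuming a local server with runtime LoRA updating enabled:

    import requests

    resp = requests.post(
        "http://localhost:8000/v1/completions",
        json={"model": "no-such-adapter", "prompt": "hi", "max_tokens": 1},
    )
    # 404 if no resolver finds the adapter; 400 if a resolver found it
    # but every attempt to load it failed.
    print(resp.status_code)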
@@ -2,6 +2,8 @@
 
 import json
 import pathlib
+from asyncio import Lock
+from collections import defaultdict
 from dataclasses import dataclass
 from http import HTTPStatus
 from typing import Optional, Union
@@ -15,6 +17,7 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse,
                                               UnloadLoRAAdapterRequest)
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
+from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.utils import AtomicCounter
@@ -63,11 +66,19 @@ class OpenAIServingModels:
         self.base_model_paths = base_model_paths
         self.max_model_len = model_config.max_model_len
         self.engine_client = engine_client
+        self.model_config = model_config
 
         self.static_lora_modules = lora_modules
         self.lora_requests: list[LoRARequest] = []
         self.lora_id_counter = AtomicCounter(0)
 
+        self.lora_resolvers: list[LoRAResolver] = []
+        for lora_resolver_name in LoRAResolverRegistry.get_supported_resolvers(
+        ):
+            self.lora_resolvers.append(
+                LoRAResolverRegistry.get_resolver(lora_resolver_name))
+        self.lora_resolver_lock: dict[str, Lock] = defaultdict(Lock)
+
         self.prompt_adapter_requests = []
         if prompt_adapters is not None:
             for i, prompt_adapter in enumerate(prompt_adapters, start=1):
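The constructor change above is the plug-in point: every resolver known to LoRAResolverRegistry when OpenAIServingModels is constructed gets consulted by resolve_lora below. A toy resolver as a sketch only; the register_resolver(name, resolver) hook and the exact LoRAResolver.resolve_lora signature are assumptions inferred from the calls visible in this diff, not confirmed by it:

    import os
    from typing import Optional

    from vllm.lora.request import LoRARequest
    from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry


    class LocalDirResolver(LoRAResolver):
        """Toy resolver mapping adapter names to subdirectories of /adapters."""

        async def resolve_lora(self, base_model_name: str,
                               lora_name: str) -> Optional[LoRARequest]:
            path = os.path.join("/adapters", lora_name)
            if not os.path.isdir(path):
                return None  # not found here; the next resolver may try
            return LoRARequest(
                lora_name=lora_name,
                lora_int_id=0,  # placeholder; resolve_lora() assigns the id
                lora_path=path)


    # Assumed registration hook (not shown in this diff); must run before
    # OpenAIServingModels is constructed so get_supported_resolvers() sees it.
    LoRAResolverRegistry.register_resolver("local_dir", LocalDirResolver())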
@@ -234,6 +245,65 @@ class OpenAIServingModels:
 
         return None
 
+    async def resolve_lora(
+            self, lora_name: str) -> Union[LoRARequest, ErrorResponse]:
+        """Attempt to resolve a LoRA adapter using available resolvers.
+
+        Args:
+            lora_name: Name/identifier of the LoRA adapter
+
+        Returns:
+            LoRARequest if found and loaded successfully.
+            ErrorResponse (404) if no resolver finds the adapter.
+            ErrorResponse (400) if adapter(s) are found but none load.
+        """
+        async with self.lora_resolver_lock[lora_name]:
+            # First check if this LoRA is already loaded
+            for existing in self.lora_requests:
+                if existing.lora_name == lora_name:
+                    return existing
+
+            base_model_name = self.model_config.model
+            unique_id = self.lora_id_counter.inc(1)
+            found_adapter = False
+
+            # Try to resolve using available resolvers
+            for resolver in self.lora_resolvers:
+                lora_request = await resolver.resolve_lora(
+                    base_model_name, lora_name)
+
+                if lora_request is not None:
+                    found_adapter = True
+                    lora_request.lora_int_id = unique_id
+
+                    try:
+                        await self.engine_client.add_lora(lora_request)
+                        self.lora_requests.append(lora_request)
+                        logger.info(
+                            "Resolved and loaded LoRA adapter '%s' using %s",
+                            lora_name, resolver.__class__.__name__)
+                        return lora_request
+                    except BaseException as e:
+                        logger.warning(
+                            "Failed to load LoRA '%s' resolved by %s: %s. "
+                            "Trying next resolver.", lora_name,
+                            resolver.__class__.__name__, e)
+                        continue
+
+            if found_adapter:
+                # An adapter was found, but all attempts to load it failed.
+                return create_error_response(
+                    message=(f"LoRA adapter '{lora_name}' was found "
+                             "but could not be loaded."),
+                    err_type="BadRequestError",
+                    status_code=HTTPStatus.BAD_REQUEST)
+            else:
+                # No adapter was found
+                return create_error_response(
+                    message=f"LoRA adapter {lora_name} does not exist",
+                    err_type="NotFoundError",
+                    status_code=HTTPStatus.NOT_FOUND)
+
+
+def create_error_response(
+        message: str,
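One detail of resolve_lora worth calling out for reviewers: the per-name defaultdict(Lock) combined with the re-check of self.lora_requests inside the lock is a double-checked pattern, so concurrent requests for the same adapter perform a single download/load while requests for different adapters proceed in parallel. The same idiom in isolation (illustrative only, not vLLM code):

    import asyncio
    from collections import defaultdict

    _locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
    _loaded: dict[str, object] = {}

    async def load_once(name: str) -> object:
        async with _locks[name]:      # serialize per name, not globally
            if name in _loaded:       # re-check under the lock
                return _loaded[name]
            await asyncio.sleep(0.1)  # stand-in for the expensive load
            _loaded[name] = object()
            return _loaded[name]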