[Misc] Enable vLLM to Dynamically Load LoRA from a Remote Server (#10546)
Signed-off-by: Angky William <angkywilliam@Angkys-MacBook-Pro.local> Co-authored-by: Angky William <angkywilliam@Angkys-MacBook-Pro.local>
This commit is contained in:
@@ -10,6 +10,7 @@ from fastapi import Request
|
||||
from pydantic import Field
|
||||
from starlette.datastructures import Headers
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
# yapf conflicts with isort for this block
|
||||
@@ -125,18 +126,29 @@ class OpenAIServing:
|
||||
self,
|
||||
request: AnyRequest,
|
||||
) -> Optional[ErrorResponse]:
|
||||
|
||||
error_response = None
|
||||
|
||||
if self._is_model_supported(request.model):
|
||||
return None
|
||||
if request.model in [
|
||||
lora.lora_name for lora in self.models.lora_requests
|
||||
]:
|
||||
return None
|
||||
if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING and request.model and (
|
||||
load_result := await self.models.resolve_lora(request.model)):
|
||||
if isinstance(load_result, LoRARequest):
|
||||
return None
|
||||
if isinstance(load_result, ErrorResponse) and \
|
||||
load_result.code == HTTPStatus.BAD_REQUEST.value:
|
||||
error_response = load_result
|
||||
if request.model in [
|
||||
prompt_adapter.prompt_adapter_name
|
||||
for prompt_adapter in self.models.prompt_adapter_requests
|
||||
]:
|
||||
return None
|
||||
return self.create_error_response(
|
||||
|
||||
return error_response or self.create_error_response(
|
||||
message=f"The model `{request.model}` does not exist.",
|
||||
err_type="NotFoundError",
|
||||
status_code=HTTPStatus.NOT_FOUND)
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
import json
|
||||
import pathlib
|
||||
from asyncio import Lock
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from http import HTTPStatus
|
||||
from typing import Optional, Union
|
||||
@@ -15,6 +17,7 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse,
|
||||
UnloadLoRAAdapterRequest)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.utils import AtomicCounter
|
||||
|
||||
@@ -63,11 +66,19 @@ class OpenAIServingModels:
|
||||
self.base_model_paths = base_model_paths
|
||||
self.max_model_len = model_config.max_model_len
|
||||
self.engine_client = engine_client
|
||||
self.model_config = model_config
|
||||
|
||||
self.static_lora_modules = lora_modules
|
||||
self.lora_requests: list[LoRARequest] = []
|
||||
self.lora_id_counter = AtomicCounter(0)
|
||||
|
||||
self.lora_resolvers: list[LoRAResolver] = []
|
||||
for lora_resolver_name in LoRAResolverRegistry.get_supported_resolvers(
|
||||
):
|
||||
self.lora_resolvers.append(
|
||||
LoRAResolverRegistry.get_resolver(lora_resolver_name))
|
||||
self.lora_resolver_lock: dict[str, Lock] = defaultdict(Lock)
|
||||
|
||||
self.prompt_adapter_requests = []
|
||||
if prompt_adapters is not None:
|
||||
for i, prompt_adapter in enumerate(prompt_adapters, start=1):
|
||||
@@ -234,6 +245,65 @@ class OpenAIServingModels:
|
||||
|
||||
return None
|
||||
|
||||
async def resolve_lora(
|
||||
self, lora_name: str) -> Union[LoRARequest, ErrorResponse]:
|
||||
"""Attempt to resolve a LoRA adapter using available resolvers.
|
||||
|
||||
Args:
|
||||
lora_name: Name/identifier of the LoRA adapter
|
||||
|
||||
Returns:
|
||||
LoRARequest if found and loaded successfully.
|
||||
ErrorResponse (404) if no resolver finds the adapter.
|
||||
ErrorResponse (400) if adapter(s) are found but none load.
|
||||
"""
|
||||
async with self.lora_resolver_lock[lora_name]:
|
||||
# First check if this LoRA is already loaded
|
||||
for existing in self.lora_requests:
|
||||
if existing.lora_name == lora_name:
|
||||
return existing
|
||||
|
||||
base_model_name = self.model_config.model
|
||||
unique_id = self.lora_id_counter.inc(1)
|
||||
found_adapter = False
|
||||
|
||||
# Try to resolve using available resolvers
|
||||
for resolver in self.lora_resolvers:
|
||||
lora_request = await resolver.resolve_lora(
|
||||
base_model_name, lora_name)
|
||||
|
||||
if lora_request is not None:
|
||||
found_adapter = True
|
||||
lora_request.lora_int_id = unique_id
|
||||
|
||||
try:
|
||||
await self.engine_client.add_lora(lora_request)
|
||||
self.lora_requests.append(lora_request)
|
||||
logger.info(
|
||||
"Resolved and loaded LoRA adapter '%s' using %s",
|
||||
lora_name, resolver.__class__.__name__)
|
||||
return lora_request
|
||||
except BaseException as e:
|
||||
logger.warning(
|
||||
"Failed to load LoRA '%s' resolved by %s: %s. "
|
||||
"Trying next resolver.", lora_name,
|
||||
resolver.__class__.__name__, e)
|
||||
continue
|
||||
|
||||
if found_adapter:
|
||||
# An adapter was found, but all attempts to load it failed.
|
||||
return create_error_response(
|
||||
message=(f"LoRA adapter '{lora_name}' was found "
|
||||
"but could not be loaded."),
|
||||
err_type="BadRequestError",
|
||||
status_code=HTTPStatus.BAD_REQUEST)
|
||||
else:
|
||||
# No adapter was found
|
||||
return create_error_response(
|
||||
message=f"LoRA adapter {lora_name} does not exist",
|
||||
err_type="NotFoundError",
|
||||
status_code=HTTPStatus.NOT_FOUND)
|
||||
|
||||
|
||||
def create_error_response(
|
||||
message: str,
|
||||
|
||||
83
vllm/lora/resolver.py
Normal file
83
vllm/lora/resolver.py
Normal file
@@ -0,0 +1,83 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import AbstractSet, Dict, Optional
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class LoRAResolver(ABC):
|
||||
"""Base class for LoRA adapter resolvers.
|
||||
|
||||
This class defines the interface for resolving and fetching LoRA adapters.
|
||||
Implementations of this class should handle the logic for locating and
|
||||
downloading LoRA adapters from various sources (e.g. S3, cloud storage,
|
||||
etc.).
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def resolve_lora(self, base_model_name: str,
|
||||
lora_name: str) -> Optional[LoRARequest]:
|
||||
"""Abstract method to resolve and fetch a LoRA model adapter.
|
||||
|
||||
Implements logic to locate and download LoRA adapter based on the name.
|
||||
Implementations might fetch from a blob storage or other sources.
|
||||
|
||||
Args:
|
||||
base_model_name: The name/identifier of the base model to resolve.
|
||||
lora_name: The name/identifier of the LoRA model to resolve.
|
||||
|
||||
Returns:
|
||||
Optional[LoRARequest]: The resolved LoRA model information, or None
|
||||
if the LoRA model cannot be found.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class _LoRAResolverRegistry:
|
||||
resolvers: Dict[str, LoRAResolver] = field(default_factory=dict)
|
||||
|
||||
def get_supported_resolvers(self) -> AbstractSet[str]:
|
||||
"""Get all registered resolver names."""
|
||||
return self.resolvers.keys()
|
||||
|
||||
def register_resolver(
|
||||
self,
|
||||
resolver_name: str,
|
||||
resolver: LoRAResolver,
|
||||
) -> None:
|
||||
"""Register a LoRA resolver.
|
||||
Args:
|
||||
resolver_name: Name to register the resolver under.
|
||||
resolver: The LoRA resolver instance to register.
|
||||
"""
|
||||
if resolver_name in self.resolvers:
|
||||
logger.warning(
|
||||
"LoRA resolver %s is already registered, and will be "
|
||||
"overwritten by the new resolver instance %s.", resolver_name,
|
||||
resolver)
|
||||
|
||||
self.resolvers[resolver_name] = resolver
|
||||
|
||||
def get_resolver(self, resolver_name: str) -> LoRAResolver:
|
||||
"""Get a registered resolver instance by name.
|
||||
Args:
|
||||
resolver_name: Name of the resolver to get.
|
||||
Returns:
|
||||
The resolver instance.
|
||||
Raises:
|
||||
KeyError: If the resolver is not found in the registry.
|
||||
"""
|
||||
if resolver_name not in self.resolvers:
|
||||
raise KeyError(
|
||||
f"LoRA resolver '{resolver_name}' not found. "
|
||||
f"Available resolvers: {list(self.resolvers.keys())}")
|
||||
return self.resolvers[resolver_name]
|
||||
|
||||
|
||||
LoRAResolverRegistry = _LoRAResolverRegistry()
|
||||
Reference in New Issue
Block a user