[Lora][Frontend]Add default local directory LoRA resolver plugin. (#16855)
Signed-off-by: jberkhahn <jaberkha@us.ibm.com>
This commit is contained in:
committed by
GitHub
parent
d19110204c
commit
98ea35601c
@@ -68,6 +68,7 @@ if TYPE_CHECKING:
|
||||
VLLM_ALLOW_LONG_MAX_MODEL_LEN: bool = False
|
||||
VLLM_RPC_TIMEOUT: int = 10000 # ms
|
||||
VLLM_PLUGINS: Optional[list[str]] = None
|
||||
VLLM_LORA_RESOLVER_CACHE_DIR: Optional[str] = None
|
||||
VLLM_TORCH_PROFILER_DIR: Optional[str] = None
|
||||
VLLM_USE_TRITON_AWQ: bool = False
|
||||
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
|
||||
@@ -503,6 +504,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
lambda: None if "VLLM_PLUGINS" not in os.environ else os.environ[
|
||||
"VLLM_PLUGINS"].split(","),
|
||||
|
||||
# a local directory to look in for unrecognized LoRA adapters.
|
||||
# only works if plugins are enabled and
|
||||
# VLLM_ALLOW_RUNTIME_LORA_UPDATING is enabled.
|
||||
"VLLM_LORA_RESOLVER_CACHE_DIR":
|
||||
lambda: os.getenv("VLLM_LORA_RESOLVER_CACHE_DIR", None),
|
||||
|
||||
# Enables torch profiler if set. Path to the directory where torch profiler
|
||||
# traces are saved. Note that it must be an absolute path.
|
||||
"VLLM_TORCH_PROFILER_DIR":
|
||||
|
||||
15
vllm/plugins/lora_resolvers/README.md
Normal file
15
vllm/plugins/lora_resolvers/README.md
Normal file
@@ -0,0 +1,15 @@
|
||||
# LoRA Resolver Plugins
|
||||
|
||||
This directory contains vLLM general plugins for dynamically discovering and loading LoRA adapters
|
||||
via the LoRAResolver plugin framework.
|
||||
|
||||
Note that `VLLM_ALLOW_RUNTIME_LORA_UPDATING` must be set to true to allow LoRA resolver plugins
|
||||
to work, and `VLLM_PLUGINS` must be set to include the desired resolver plugins.
|
||||
|
||||
# lora_filesystem_resolver
|
||||
This LoRA Resolver is installed with vLLM by default.
|
||||
To use, set `VLLM_LORA_RESOLVER_CACHE_DIR` to a local directory. When vLLM receives a request
|
||||
for a LoRA adapter `foobar` it doesn't currently recognize, it will look in that local directory
|
||||
for a subdirectory `foobar` containing a LoRA adapter. If such an adapter exists, it will
|
||||
load that adapter, and then service the request as normal. That adapter will then be available
|
||||
for future requests as normal.
|
||||
0
vllm/plugins/lora_resolvers/__init__.py
Normal file
0
vllm/plugins/lora_resolvers/__init__.py
Normal file
49
vllm/plugins/lora_resolvers/filesystem_resolver.py
Normal file
49
vllm/plugins/lora_resolvers/filesystem_resolver.py
Normal file
@@ -0,0 +1,49 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import json
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.lora.resolver import LoRAResolver, LoRAResolverRegistry
|
||||
|
||||
|
||||
class FilesystemResolver(LoRAResolver):
    """Resolve LoRA adapters from subdirectories of a local directory.

    An adapter named ``foo`` is looked up at ``<lora_cache_dir>/foo``. It is
    accepted only when that directory contains an ``adapter_config.json``
    whose ``peft_type`` is ``"LORA"`` and whose ``base_model_name_or_path``
    equals the requested base model name.
    """

    def __init__(self, lora_cache_dir: str):
        """Store the directory that is searched for adapter subdirectories."""
        self.lora_cache_dir = lora_cache_dir

    def _is_inside_cache_dir(self, path: str) -> bool:
        """Return True if ``path`` is strictly below ``lora_cache_dir``."""
        # abspath() normalises ".." segments without resolving symlinks, so
        # adapters that are symlinked into the cache directory keep working.
        cache_root = os.path.abspath(self.lora_cache_dir)
        candidate = os.path.abspath(path)
        if candidate == cache_root:
            return False
        try:
            return os.path.commonpath([cache_root, candidate]) == cache_root
        except ValueError:
            # Raised e.g. for paths on different drives on Windows.
            return False

    async def resolve_lora(self, base_model_name: str,
                           lora_name: str) -> Optional[LoRARequest]:
        """Look up ``lora_name`` in the cache directory.

        Args:
            base_model_name: Name the adapter's ``base_model_name_or_path``
                must match exactly.
            lora_name: Requested adapter name, used as the subdirectory name
                below ``lora_cache_dir``.

        Returns:
            A ``LoRARequest`` pointing at the adapter directory, or ``None``
            when no matching LoRA adapter is found there.
        """
        lora_path = os.path.join(self.lora_cache_dir, lora_name)

        # lora_name is taken from the request, so refuse names such as
        # "../x" or absolute paths that would leave the cache directory.
        if not self._is_inside_cache_dir(lora_path):
            return None

        adapter_config_path = os.path.join(lora_path, "adapter_config.json")
        try:
            with open(adapter_config_path, encoding="utf-8") as file:
                adapter_config = json.load(file)
        except (FileNotFoundError, NotADirectoryError):
            # No such adapter directory, or no config file inside it.
            return None

        # A config that is not a JSON object, or that lacks either key, is
        # simply not a matching LoRA adapter.
        if not isinstance(adapter_config, dict):
            return None
        if adapter_config.get("peft_type") != "LORA":
            return None
        if adapter_config.get("base_model_name_or_path") != base_model_name:
            return None

        # NOTE(review): str hashes are salted per process unless
        # PYTHONHASHSEED is fixed, so this ID is not stable across
        # processes or restarts -- confirm that is acceptable to callers.
        return LoRARequest(lora_name=lora_name,
                           lora_int_id=abs(hash(lora_name)),
                           lora_path=lora_path)
|
||||
|
||||
|
||||
def register_filesystem_resolver():
    """Register the filesystem LoRA Resolver with vLLM.

    Reads ``envs.VLLM_LORA_RESOLVER_CACHE_DIR``. When it is unset or empty,
    nothing is registered. Otherwise a ``FilesystemResolver`` for that
    directory is registered under the name ``"Filesystem Resolver"``.

    Raises:
        ValueError: If the variable is set but does not name an existing
            directory.
    """
    lora_cache_dir = envs.VLLM_LORA_RESOLVER_CACHE_DIR
    if not lora_cache_dir:
        return

    # isdir() is already False for paths that do not exist.
    if not os.path.isdir(lora_cache_dir):
        # Adjacent literals instead of a backslash continuation inside the
        # string, which would embed the next line's indentation in the
        # message.
        raise ValueError(
            "VLLM_LORA_RESOLVER_CACHE_DIR must be set to a valid directory "
            "for Filesystem Resolver plugin to function")

    LoRAResolverRegistry.register_resolver(
        "Filesystem Resolver", FilesystemResolver(lora_cache_dir))
|
||||
Reference in New Issue
Block a user