[Bugfix][Refactor] Unify model management in frontend (#11660)

Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
This commit is contained in:
Joe Runde
2024-12-31 18:21:51 -08:00
committed by GitHub
parent 0c6f998554
commit 4db72e57f6
15 changed files with 365 additions and 307 deletions

View File

@@ -58,7 +58,9 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
OpenAIServingModels)
from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
from vllm.entrypoints.openai.serving_score import OpenAIServingScores
from vllm.entrypoints.openai.serving_tokenization import (
@@ -269,6 +271,10 @@ def base(request: Request) -> OpenAIServing:
return tokenization(request)
def models(request: Request) -> OpenAIServingModels:
return request.app.state.openai_serving_models
def chat(request: Request) -> Optional[OpenAIServingChat]:
return request.app.state.openai_serving_chat
@@ -336,10 +342,10 @@ async def detokenize(request: DetokenizeRequest, raw_request: Request):
@router.get("/v1/models")
async def show_available_models(raw_request: Request):
handler = base(raw_request)
handler = models(raw_request)
models = await handler.show_available_models()
return JSONResponse(content=models.model_dump())
models_ = await handler.show_available_models()
return JSONResponse(content=models_.model_dump())
@router.get("/version")
@@ -505,26 +511,22 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
@router.post("/v1/load_lora_adapter")
async def load_lora_adapter(request: LoadLoraAdapterRequest,
raw_request: Request):
for route in [chat, completion, embedding]:
handler = route(raw_request)
if handler is not None:
response = await handler.load_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)
handler = models(raw_request)
response = await handler.load_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)
return Response(status_code=200, content=response)
@router.post("/v1/unload_lora_adapter")
async def unload_lora_adapter(request: UnloadLoraAdapterRequest,
raw_request: Request):
for route in [chat, completion, embedding]:
handler = route(raw_request)
if handler is not None:
response = await handler.unload_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)
handler = models(raw_request)
response = await handler.unload_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)
return Response(status_code=200, content=response)
@@ -628,13 +630,18 @@ def init_app_state(
resolved_chat_template = load_chat_template(args.chat_template)
logger.info("Using supplied chat template:\n%s", resolved_chat_template)
state.openai_serving_models = OpenAIServingModels(
model_config=model_config,
base_model_paths=base_model_paths,
lora_modules=args.lora_modules,
prompt_adapters=args.prompt_adapters,
)
# TODO: The chat template is now broken for lora adapters :(
state.openai_serving_chat = OpenAIServingChat(
engine_client,
model_config,
base_model_paths,
state.openai_serving_models,
args.response_role,
lora_modules=args.lora_modules,
prompt_adapters=args.prompt_adapters,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
@@ -646,16 +653,14 @@ def init_app_state(
state.openai_serving_completion = OpenAIServingCompletion(
engine_client,
model_config,
base_model_paths,
lora_modules=args.lora_modules,
prompt_adapters=args.prompt_adapters,
state.openai_serving_models,
request_logger=request_logger,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
) if model_config.runner_type == "generate" else None
state.openai_serving_pooling = OpenAIServingPooling(
engine_client,
model_config,
base_model_paths,
state.openai_serving_models,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
@@ -663,7 +668,7 @@ def init_app_state(
state.openai_serving_embedding = OpenAIServingEmbedding(
engine_client,
model_config,
base_model_paths,
state.openai_serving_models,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
@@ -671,14 +676,13 @@ def init_app_state(
state.openai_serving_scores = OpenAIServingScores(
engine_client,
model_config,
base_model_paths,
state.openai_serving_models,
request_logger=request_logger
) if model_config.task == "score" else None
state.openai_serving_tokenization = OpenAIServingTokenization(
engine_client,
model_config,
base_model_paths,
lora_modules=args.lora_modules,
state.openai_serving_models,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,

View File

@@ -12,7 +12,7 @@ from typing import List, Optional, Sequence, Union, get_args
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.entrypoints.chat_utils import (ChatTemplateContentFormatOption,
validate_chat_template)
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
from vllm.entrypoints.openai.serving_models import (LoRAModulePath,
PromptAdapterPath)
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
from vllm.utils import FlexibleArgumentParser

View File

@@ -20,7 +20,8 @@ from vllm.entrypoints.openai.protocol import (BatchRequestInput,
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_engine import BaseModelPath
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
OpenAIServingModels)
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, random_uuid
from vllm.version import __version__ as VLLM_VERSION
@@ -213,13 +214,17 @@ async def main(args):
request_logger = RequestLogger(max_log_len=args.max_log_len)
# Create the openai serving objects.
openai_serving_models = OpenAIServingModels(
model_config=model_config,
base_model_paths=base_model_paths,
lora_modules=None,
prompt_adapters=None,
)
openai_serving_chat = OpenAIServingChat(
engine,
model_config,
base_model_paths,
openai_serving_models,
args.response_role,
lora_modules=None,
prompt_adapters=None,
request_logger=request_logger,
chat_template=None,
chat_template_content_format="auto",
@@ -228,7 +233,7 @@ async def main(args):
openai_serving_embedding = OpenAIServingEmbedding(
engine,
model_config,
base_model_paths,
openai_serving_models,
request_logger=request_logger,
chat_template=None,
chat_template_content_format="auto",

View File

@@ -21,10 +21,8 @@ from vllm.entrypoints.openai.protocol import (
ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage,
DeltaToolCall, ErrorResponse, FunctionCall, PromptTokenUsageInfo,
RequestResponseMetadata, ToolCall, UsageInfo)
from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
LoRAModulePath,
OpenAIServing,
PromptAdapterPath)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager
from vllm.logger import init_logger
from vllm.outputs import CompletionOutput, RequestOutput
@@ -42,11 +40,9 @@ class OpenAIServingChat(OpenAIServing):
self,
engine_client: EngineClient,
model_config: ModelConfig,
base_model_paths: List[BaseModelPath],
models: OpenAIServingModels,
response_role: str,
*,
lora_modules: Optional[List[LoRAModulePath]],
prompt_adapters: Optional[List[PromptAdapterPath]],
request_logger: Optional[RequestLogger],
chat_template: Optional[str],
chat_template_content_format: ChatTemplateContentFormatOption,
@@ -57,9 +53,7 @@ class OpenAIServingChat(OpenAIServing):
) -> None:
super().__init__(engine_client=engine_client,
model_config=model_config,
base_model_paths=base_model_paths,
lora_modules=lora_modules,
prompt_adapters=prompt_adapters,
models=models,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids)
@@ -126,7 +120,7 @@ class OpenAIServingChat(OpenAIServing):
prompt_adapter_request,
) = self._maybe_get_adapters(request)
model_name = self._get_model_name(lora_request)
model_name = self.models.model_name(lora_request)
tokenizer = await self.engine_client.get_tokenizer(lora_request)

View File

@@ -21,10 +21,8 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
RequestResponseMetadata,
UsageInfo)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
LoRAModulePath,
OpenAIServing,
PromptAdapterPath)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams, SamplingParams
@@ -41,18 +39,14 @@ class OpenAIServingCompletion(OpenAIServing):
self,
engine_client: EngineClient,
model_config: ModelConfig,
base_model_paths: List[BaseModelPath],
models: OpenAIServingModels,
*,
lora_modules: Optional[List[LoRAModulePath]],
prompt_adapters: Optional[List[PromptAdapterPath]],
request_logger: Optional[RequestLogger],
return_tokens_as_token_ids: bool = False,
):
super().__init__(engine_client=engine_client,
model_config=model_config,
base_model_paths=base_model_paths,
lora_modules=lora_modules,
prompt_adapters=prompt_adapters,
models=models,
request_logger=request_logger,
return_tokens_as_token_ids=return_tokens_as_token_ids)
diff_sampling_param = self.model_config.get_diff_sampling_param()
@@ -170,7 +164,7 @@ class OpenAIServingCompletion(OpenAIServing):
result_generator = merge_async_iterators(*generators)
model_name = self._get_model_name(lora_request)
model_name = self.models.model_name(lora_request)
num_prompts = len(engine_prompts)
# Similar to the OpenAI API, when n != best_of, we do not stream the

View File

@@ -16,7 +16,8 @@ from vllm.entrypoints.openai.protocol import (EmbeddingChatRequest,
EmbeddingResponse,
EmbeddingResponseData,
ErrorResponse, UsageInfo)
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.logger import init_logger
from vllm.outputs import (EmbeddingOutput, EmbeddingRequestOutput,
PoolingRequestOutput)
@@ -46,7 +47,7 @@ class OpenAIServingEmbedding(OpenAIServing):
self,
engine_client: EngineClient,
model_config: ModelConfig,
base_model_paths: List[BaseModelPath],
models: OpenAIServingModels,
*,
request_logger: Optional[RequestLogger],
chat_template: Optional[str],
@@ -54,9 +55,7 @@ class OpenAIServingEmbedding(OpenAIServing):
) -> None:
super().__init__(engine_client=engine_client,
model_config=model_config,
base_model_paths=base_model_paths,
lora_modules=None,
prompt_adapters=None,
models=models,
request_logger=request_logger)
self.chat_template = chat_template

View File

@@ -1,7 +1,5 @@
import json
import pathlib
from concurrent.futures.thread import ThreadPoolExecutor
from dataclasses import dataclass
from http import HTTPStatus
from typing import (Any, Callable, Dict, Iterable, Iterator, List, Mapping,
Optional, Sequence, Tuple, TypedDict, Union)
@@ -28,13 +26,10 @@ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DetokenizeRequest,
EmbeddingChatRequest,
EmbeddingCompletionRequest,
ErrorResponse,
LoadLoraAdapterRequest,
ModelCard, ModelList,
ModelPermission, ScoreRequest,
ErrorResponse, ScoreRequest,
TokenizeChatRequest,
TokenizeCompletionRequest,
UnloadLoraAdapterRequest)
TokenizeCompletionRequest)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.tool_parsers import ToolParser
# yapf: enable
from vllm.inputs import TokensPrompt
@@ -48,30 +43,10 @@ from vllm.sequence import Logprob
from vllm.tracing import (contains_trace_headers, extract_trace_headers,
log_tracing_disabled_warning)
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.utils import AtomicCounter, is_list_of, make_async, random_uuid
from vllm.utils import is_list_of, make_async, random_uuid
logger = init_logger(__name__)
@dataclass
class BaseModelPath:
name: str
model_path: str
@dataclass
class PromptAdapterPath:
name: str
local_path: str
@dataclass
class LoRAModulePath:
name: str
path: str
base_model_name: Optional[str] = None
CompletionLikeRequest = Union[CompletionRequest, DetokenizeRequest,
EmbeddingCompletionRequest, ScoreRequest,
TokenizeCompletionRequest]
@@ -96,10 +71,8 @@ class OpenAIServing:
self,
engine_client: EngineClient,
model_config: ModelConfig,
base_model_paths: List[BaseModelPath],
models: OpenAIServingModels,
*,
lora_modules: Optional[List[LoRAModulePath]],
prompt_adapters: Optional[List[PromptAdapterPath]],
request_logger: Optional[RequestLogger],
return_tokens_as_token_ids: bool = False,
):
@@ -109,35 +82,7 @@ class OpenAIServing:
self.model_config = model_config
self.max_model_len = model_config.max_model_len
self.base_model_paths = base_model_paths
self.lora_id_counter = AtomicCounter(0)
self.lora_requests = []
if lora_modules is not None:
self.lora_requests = [
LoRARequest(lora_name=lora.name,
lora_int_id=i,
lora_path=lora.path,
base_model_name=lora.base_model_name
if lora.base_model_name
and self._is_model_supported(lora.base_model_name)
else self.base_model_paths[0].name)
for i, lora in enumerate(lora_modules, start=1)
]
self.prompt_adapter_requests = []
if prompt_adapters is not None:
for i, prompt_adapter in enumerate(prompt_adapters, start=1):
with pathlib.Path(prompt_adapter.local_path,
"adapter_config.json").open() as f:
adapter_config = json.load(f)
num_virtual_tokens = adapter_config["num_virtual_tokens"]
self.prompt_adapter_requests.append(
PromptAdapterRequest(
prompt_adapter_name=prompt_adapter.name,
prompt_adapter_id=i,
prompt_adapter_local_path=prompt_adapter.local_path,
prompt_adapter_num_virtual_tokens=num_virtual_tokens))
self.models = models
self.request_logger = request_logger
self.return_tokens_as_token_ids = return_tokens_as_token_ids
@@ -150,33 +95,6 @@ class OpenAIServing:
self._tokenize_prompt_input_or_inputs,
executor=self._tokenizer_executor)
async def show_available_models(self) -> ModelList:
"""Show available models. Right now we only have one model."""
model_cards = [
ModelCard(id=base_model.name,
max_model_len=self.max_model_len,
root=base_model.model_path,
permission=[ModelPermission()])
for base_model in self.base_model_paths
]
lora_cards = [
ModelCard(id=lora.lora_name,
root=lora.local_path,
parent=lora.base_model_name if lora.base_model_name else
self.base_model_paths[0].name,
permission=[ModelPermission()])
for lora in self.lora_requests
]
prompt_adapter_cards = [
ModelCard(id=prompt_adapter.prompt_adapter_name,
root=self.base_model_paths[0].name,
permission=[ModelPermission()])
for prompt_adapter in self.prompt_adapter_requests
]
model_cards.extend(lora_cards)
model_cards.extend(prompt_adapter_cards)
return ModelList(data=model_cards)
def create_error_response(
self,
message: str,
@@ -205,11 +123,13 @@ class OpenAIServing:
) -> Optional[ErrorResponse]:
if self._is_model_supported(request.model):
return None
if request.model in [lora.lora_name for lora in self.lora_requests]:
if request.model in [
lora.lora_name for lora in self.models.lora_requests
]:
return None
if request.model in [
prompt_adapter.prompt_adapter_name
for prompt_adapter in self.prompt_adapter_requests
for prompt_adapter in self.models.prompt_adapter_requests
]:
return None
return self.create_error_response(
@@ -223,10 +143,10 @@ class OpenAIServing:
None, PromptAdapterRequest]]:
if self._is_model_supported(request.model):
return None, None
for lora in self.lora_requests:
for lora in self.models.lora_requests:
if request.model == lora.lora_name:
return lora, None
for prompt_adapter in self.prompt_adapter_requests:
for prompt_adapter in self.models.prompt_adapter_requests:
if request.model == prompt_adapter.prompt_adapter_name:
return None, prompt_adapter
# if _check_model has been called earlier, this will be unreachable
@@ -588,91 +508,5 @@ class OpenAIServing:
return logprob.decoded_token
return tokenizer.decode(token_id)
async def _check_load_lora_adapter_request(
self, request: LoadLoraAdapterRequest) -> Optional[ErrorResponse]:
# Check if both 'lora_name' and 'lora_path' are provided
if not request.lora_name or not request.lora_path:
return self.create_error_response(
message="Both 'lora_name' and 'lora_path' must be provided.",
err_type="InvalidUserInput",
status_code=HTTPStatus.BAD_REQUEST)
# Check if the lora adapter with the given name already exists
if any(lora_request.lora_name == request.lora_name
for lora_request in self.lora_requests):
return self.create_error_response(
message=
f"The lora adapter '{request.lora_name}' has already been"
"loaded.",
err_type="InvalidUserInput",
status_code=HTTPStatus.BAD_REQUEST)
return None
async def _check_unload_lora_adapter_request(
self,
request: UnloadLoraAdapterRequest) -> Optional[ErrorResponse]:
# Check if either 'lora_name' or 'lora_int_id' is provided
if not request.lora_name and not request.lora_int_id:
return self.create_error_response(
message=
"either 'lora_name' and 'lora_int_id' needs to be provided.",
err_type="InvalidUserInput",
status_code=HTTPStatus.BAD_REQUEST)
# Check if the lora adapter with the given name exists
if not any(lora_request.lora_name == request.lora_name
for lora_request in self.lora_requests):
return self.create_error_response(
message=
f"The lora adapter '{request.lora_name}' cannot be found.",
err_type="InvalidUserInput",
status_code=HTTPStatus.BAD_REQUEST)
return None
async def load_lora_adapter(
self,
request: LoadLoraAdapterRequest) -> Union[ErrorResponse, str]:
error_check_ret = await self._check_load_lora_adapter_request(request)
if error_check_ret is not None:
return error_check_ret
lora_name, lora_path = request.lora_name, request.lora_path
unique_id = self.lora_id_counter.inc(1)
self.lora_requests.append(
LoRARequest(lora_name=lora_name,
lora_int_id=unique_id,
lora_path=lora_path))
return f"Success: LoRA adapter '{lora_name}' added successfully."
async def unload_lora_adapter(
self,
request: UnloadLoraAdapterRequest) -> Union[ErrorResponse, str]:
error_check_ret = await self._check_unload_lora_adapter_request(request
)
if error_check_ret is not None:
return error_check_ret
lora_name = request.lora_name
self.lora_requests = [
lora_request for lora_request in self.lora_requests
if lora_request.lora_name != lora_name
]
return f"Success: LoRA adapter '{lora_name}' removed successfully."
def _is_model_supported(self, model_name):
return any(model.name == model_name for model in self.base_model_paths)
def _get_model_name(self, lora: Optional[LoRARequest]):
"""
Returns the appropriate model name depending on the availability
and support of the LoRA or base model.
Parameters:
- lora: LoRARequest that contain a base_model_name.
Returns:
- str: The name of the base model or the first available model path.
"""
if lora is not None:
return lora.lora_name
return self.base_model_paths[0].name
return self.models.is_base_model(model_name)

View File

@@ -0,0 +1,210 @@
import json
import pathlib
from dataclasses import dataclass
from http import HTTPStatus
from typing import List, Optional, Union
from vllm.config import ModelConfig
from vllm.entrypoints.openai.protocol import (ErrorResponse,
LoadLoraAdapterRequest,
ModelCard, ModelList,
ModelPermission,
UnloadLoraAdapterRequest)
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.utils import AtomicCounter
@dataclass
class BaseModelPath:
name: str
model_path: str
@dataclass
class PromptAdapterPath:
name: str
local_path: str
@dataclass
class LoRAModulePath:
name: str
path: str
base_model_name: Optional[str] = None
class OpenAIServingModels:
"""Shared instance to hold data about the loaded base model(s) and adapters.
Handles the routes:
- /v1/models
- /v1/load_lora_adapter
- /v1/unload_lora_adapter
"""
def __init__(
self,
model_config: ModelConfig,
base_model_paths: List[BaseModelPath],
*,
lora_modules: Optional[List[LoRAModulePath]] = None,
prompt_adapters: Optional[List[PromptAdapterPath]] = None,
):
super().__init__()
self.base_model_paths = base_model_paths
self.max_model_len = model_config.max_model_len
self.lora_id_counter = AtomicCounter(0)
self.lora_requests = []
if lora_modules is not None:
self.lora_requests = [
LoRARequest(lora_name=lora.name,
lora_int_id=i,
lora_path=lora.path,
base_model_name=lora.base_model_name
if lora.base_model_name
and self.is_base_model(lora.base_model_name) else
self.base_model_paths[0].name)
for i, lora in enumerate(lora_modules, start=1)
]
self.prompt_adapter_requests = []
if prompt_adapters is not None:
for i, prompt_adapter in enumerate(prompt_adapters, start=1):
with pathlib.Path(prompt_adapter.local_path,
"adapter_config.json").open() as f:
adapter_config = json.load(f)
num_virtual_tokens = adapter_config["num_virtual_tokens"]
self.prompt_adapter_requests.append(
PromptAdapterRequest(
prompt_adapter_name=prompt_adapter.name,
prompt_adapter_id=i,
prompt_adapter_local_path=prompt_adapter.local_path,
prompt_adapter_num_virtual_tokens=num_virtual_tokens))
def is_base_model(self, model_name):
return any(model.name == model_name for model in self.base_model_paths)
def model_name(self, lora_request: Optional[LoRARequest] = None) -> str:
"""Returns the appropriate model name depending on the availability
and support of the LoRA or base model.
Parameters:
- lora: LoRARequest that contain a base_model_name.
Returns:
- str: The name of the base model or the first available model path.
"""
if lora_request is not None:
return lora_request.lora_name
return self.base_model_paths[0].name
async def show_available_models(self) -> ModelList:
"""Show available models. This includes the base model and all
adapters"""
model_cards = [
ModelCard(id=base_model.name,
max_model_len=self.max_model_len,
root=base_model.model_path,
permission=[ModelPermission()])
for base_model in self.base_model_paths
]
lora_cards = [
ModelCard(id=lora.lora_name,
root=lora.local_path,
parent=lora.base_model_name if lora.base_model_name else
self.base_model_paths[0].name,
permission=[ModelPermission()])
for lora in self.lora_requests
]
prompt_adapter_cards = [
ModelCard(id=prompt_adapter.prompt_adapter_name,
root=self.base_model_paths[0].name,
permission=[ModelPermission()])
for prompt_adapter in self.prompt_adapter_requests
]
model_cards.extend(lora_cards)
model_cards.extend(prompt_adapter_cards)
return ModelList(data=model_cards)
async def load_lora_adapter(
self,
request: LoadLoraAdapterRequest) -> Union[ErrorResponse, str]:
error_check_ret = await self._check_load_lora_adapter_request(request)
if error_check_ret is not None:
return error_check_ret
lora_name, lora_path = request.lora_name, request.lora_path
unique_id = self.lora_id_counter.inc(1)
self.lora_requests.append(
LoRARequest(lora_name=lora_name,
lora_int_id=unique_id,
lora_path=lora_path))
return f"Success: LoRA adapter '{lora_name}' added successfully."
async def unload_lora_adapter(
self,
request: UnloadLoraAdapterRequest) -> Union[ErrorResponse, str]:
error_check_ret = await self._check_unload_lora_adapter_request(request
)
if error_check_ret is not None:
return error_check_ret
lora_name = request.lora_name
self.lora_requests = [
lora_request for lora_request in self.lora_requests
if lora_request.lora_name != lora_name
]
return f"Success: LoRA adapter '{lora_name}' removed successfully."
async def _check_load_lora_adapter_request(
self, request: LoadLoraAdapterRequest) -> Optional[ErrorResponse]:
# Check if both 'lora_name' and 'lora_path' are provided
if not request.lora_name or not request.lora_path:
return create_error_response(
message="Both 'lora_name' and 'lora_path' must be provided.",
err_type="InvalidUserInput",
status_code=HTTPStatus.BAD_REQUEST)
# Check if the lora adapter with the given name already exists
if any(lora_request.lora_name == request.lora_name
for lora_request in self.lora_requests):
return create_error_response(
message=
f"The lora adapter '{request.lora_name}' has already been"
"loaded.",
err_type="InvalidUserInput",
status_code=HTTPStatus.BAD_REQUEST)
return None
async def _check_unload_lora_adapter_request(
self,
request: UnloadLoraAdapterRequest) -> Optional[ErrorResponse]:
# Check if either 'lora_name' or 'lora_int_id' is provided
if not request.lora_name and not request.lora_int_id:
return create_error_response(
message=
"either 'lora_name' and 'lora_int_id' needs to be provided.",
err_type="InvalidUserInput",
status_code=HTTPStatus.BAD_REQUEST)
# Check if the lora adapter with the given name exists
if not any(lora_request.lora_name == request.lora_name
for lora_request in self.lora_requests):
return create_error_response(
message=
f"The lora adapter '{request.lora_name}' cannot be found.",
err_type="InvalidUserInput",
status_code=HTTPStatus.BAD_REQUEST)
return None
def create_error_response(
message: str,
err_type: str = "BadRequestError",
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
return ErrorResponse(message=message,
type=err_type,
code=status_code.value)

View File

@@ -15,7 +15,8 @@ from vllm.entrypoints.openai.protocol import (ErrorResponse,
PoolingChatRequest,
PoolingRequest, PoolingResponse,
PoolingResponseData, UsageInfo)
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.logger import init_logger
from vllm.outputs import PoolingOutput, PoolingRequestOutput
from vllm.utils import merge_async_iterators
@@ -44,7 +45,7 @@ class OpenAIServingPooling(OpenAIServing):
self,
engine_client: EngineClient,
model_config: ModelConfig,
base_model_paths: List[BaseModelPath],
models: OpenAIServingModels,
*,
request_logger: Optional[RequestLogger],
chat_template: Optional[str],
@@ -52,9 +53,7 @@ class OpenAIServingPooling(OpenAIServing):
) -> None:
super().__init__(engine_client=engine_client,
model_config=model_config,
base_model_paths=base_model_paths,
lora_modules=None,
prompt_adapters=None,
models=models,
request_logger=request_logger)
self.chat_template = chat_template

View File

@@ -10,7 +10,8 @@ from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (ErrorResponse, ScoreRequest,
ScoreResponse, ScoreResponseData,
UsageInfo)
from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.inputs.data import TokensPrompt
from vllm.logger import init_logger
from vllm.outputs import PoolingRequestOutput, ScoringRequestOutput
@@ -50,15 +51,13 @@ class OpenAIServingScores(OpenAIServing):
self,
engine_client: EngineClient,
model_config: ModelConfig,
base_model_paths: List[BaseModelPath],
models: OpenAIServingModels,
*,
request_logger: Optional[RequestLogger],
) -> None:
super().__init__(engine_client=engine_client,
model_config=model_config,
base_model_paths=base_model_paths,
lora_modules=None,
prompt_adapters=None,
models=models,
request_logger=request_logger)
async def create_score(

View File

@@ -15,9 +15,8 @@ from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
TokenizeRequest,
TokenizeResponse)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import (BaseModelPath,
LoRAModulePath,
OpenAIServing)
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.logger import init_logger
logger = init_logger(__name__)
@@ -29,18 +28,15 @@ class OpenAIServingTokenization(OpenAIServing):
self,
engine_client: EngineClient,
model_config: ModelConfig,
base_model_paths: List[BaseModelPath],
models: OpenAIServingModels,
*,
lora_modules: Optional[List[LoRAModulePath]],
request_logger: Optional[RequestLogger],
chat_template: Optional[str],
chat_template_content_format: ChatTemplateContentFormatOption,
) -> None:
super().__init__(engine_client=engine_client,
model_config=model_config,
base_model_paths=base_model_paths,
lora_modules=lora_modules,
prompt_adapters=None,
models=models,
request_logger=request_logger)
self.chat_template = chat_template