[Frontend] Split OpenAIServingModels into OpenAIModelRegistry + OpenAIServingModels (#36536)
Signed-off-by: Sage Ahrac <sagiahrak@gmail.com>
This commit is contained in:
@@ -414,11 +414,19 @@ async def init_render_app_state(
|
||||
directly from the :class:`~vllm.config.VllmConfig`.
|
||||
"""
|
||||
from vllm.entrypoints.chat_utils import load_chat_template
|
||||
from vllm.entrypoints.openai.models.serving import OpenAIModelRegistry
|
||||
from vllm.entrypoints.serve.render.serving import OpenAIServingRender
|
||||
from vllm.plugins.io_processors import get_io_processor
|
||||
from vllm.renderers import renderer_from_config
|
||||
|
||||
served_model_names = args.served_model_name or [args.model]
|
||||
model_registry = OpenAIModelRegistry(
|
||||
model_config=vllm_config.model_config,
|
||||
base_model_paths=[
|
||||
BaseModelPath(name=name, model_path=args.model)
|
||||
for name in served_model_names
|
||||
],
|
||||
)
|
||||
|
||||
if args.enable_log_requests:
|
||||
request_logger = RequestLogger(max_log_len=args.max_log_len)
|
||||
@@ -435,7 +443,7 @@ async def init_render_app_state(
|
||||
model_config=vllm_config.model_config,
|
||||
renderer=renderer,
|
||||
io_processor=io_processor,
|
||||
served_model_names=served_model_names,
|
||||
model_registry=model_registry,
|
||||
request_logger=request_logger,
|
||||
chat_template=resolved_chat_template,
|
||||
chat_template_content_format=args.chat_template_content_format,
|
||||
@@ -447,8 +455,7 @@ async def init_render_app_state(
|
||||
log_error_stack=args.log_error_stack,
|
||||
)
|
||||
|
||||
# Expose models endpoint via the render handler.
|
||||
state.openai_serving_models = state.openai_serving_render
|
||||
state.openai_serving_models = model_registry
|
||||
|
||||
state.vllm_config = vllm_config
|
||||
# Disable stats logging — there is no engine to poll.
|
||||
|
||||
@@ -169,9 +169,7 @@ async def init_generate_state(
|
||||
model_config=engine_client.model_config,
|
||||
renderer=engine_client.renderer,
|
||||
io_processor=engine_client.io_processor,
|
||||
served_model_names=[
|
||||
mp.name for mp in state.openai_serving_models.base_model_paths
|
||||
],
|
||||
model_registry=state.openai_serving_models.registry,
|
||||
request_logger=request_logger,
|
||||
chat_template=resolved_chat_template,
|
||||
chat_template_content_format=args.chat_template_content_format,
|
||||
|
||||
@@ -5,6 +5,7 @@ from asyncio import Lock
|
||||
from collections import defaultdict
|
||||
from http import HTTPStatus
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.protocol import EngineClient
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
ErrorResponse,
|
||||
@@ -27,6 +28,51 @@ from vllm.utils.counter import AtomicCounter
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class OpenAIModelRegistry:
|
||||
"""Read-only view of the loaded base models with no engine dependency.
|
||||
|
||||
Suitable for CPU-only / render-only contexts that have no engine client
|
||||
and no LoRA support.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
base_model_paths: list[BaseModelPath],
|
||||
) -> None:
|
||||
self.model_config = model_config
|
||||
self.base_model_paths = base_model_paths
|
||||
|
||||
def is_base_model(self, model_name: str) -> bool:
|
||||
return any(model.name == model_name for model in self.base_model_paths)
|
||||
|
||||
async def check_model(self, model_name: str | None) -> ErrorResponse | None:
|
||||
"""Return an ErrorResponse if model_name is not served, else None."""
|
||||
if not model_name or self.is_base_model(model_name):
|
||||
return None
|
||||
return create_error_response(
|
||||
message=f"The model `{model_name}` does not exist.",
|
||||
err_type="NotFoundError",
|
||||
status_code=HTTPStatus.NOT_FOUND,
|
||||
param="model",
|
||||
)
|
||||
|
||||
async def show_available_models(self) -> ModelList:
|
||||
"""Show available models (base models only)."""
|
||||
max_model_len = self.model_config.max_model_len
|
||||
return ModelList(
|
||||
data=[
|
||||
ModelCard(
|
||||
id=base_model.name,
|
||||
max_model_len=max_model_len,
|
||||
root=base_model.model_path,
|
||||
permission=[ModelPermission()],
|
||||
)
|
||||
for base_model in self.base_model_paths
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class OpenAIServingModels:
|
||||
"""Shared instance to hold data about the loaded base model(s) and adapters.
|
||||
|
||||
@@ -45,6 +91,11 @@ class OpenAIServingModels:
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.registry = OpenAIModelRegistry(
|
||||
model_config=engine_client.model_config,
|
||||
base_model_paths=base_model_paths,
|
||||
)
|
||||
|
||||
self.engine_client = engine_client
|
||||
self.base_model_paths = base_model_paths
|
||||
|
||||
@@ -79,34 +130,18 @@ class OpenAIServingModels:
|
||||
if isinstance(load_result, ErrorResponse):
|
||||
raise ValueError(load_result.error.message)
|
||||
|
||||
def is_base_model(self, model_name) -> bool:
|
||||
return any(model.name == model_name for model in self.base_model_paths)
|
||||
def is_base_model(self, model_name: str) -> bool:
|
||||
return self.registry.is_base_model(model_name)
|
||||
|
||||
def model_name(self, lora_request: LoRARequest | None = None) -> str:
|
||||
"""Returns the appropriate model name depending on the availability
|
||||
and support of the LoRA or base model.
|
||||
Parameters:
|
||||
- lora: LoRARequest that contain a base_model_name.
|
||||
Returns:
|
||||
- str: The name of the base model or the first available model path.
|
||||
"""
|
||||
if lora_request is not None:
|
||||
return lora_request.lora_name
|
||||
return self.base_model_paths[0].name
|
||||
|
||||
async def show_available_models(self) -> ModelList:
|
||||
"""Show available models. This includes the base model and all adapters."""
|
||||
max_model_len = self.model_config.max_model_len
|
||||
|
||||
model_cards = [
|
||||
ModelCard(
|
||||
id=base_model.name,
|
||||
max_model_len=max_model_len,
|
||||
root=base_model.model_path,
|
||||
permission=[ModelPermission()],
|
||||
)
|
||||
for base_model in self.base_model_paths
|
||||
]
|
||||
"""Show available models. This includes the base model and all
|
||||
adapters."""
|
||||
model_list = await self.registry.show_available_models()
|
||||
lora_cards = [
|
||||
ModelCard(
|
||||
id=lora.lora_name,
|
||||
@@ -118,8 +153,8 @@ class OpenAIServingModels:
|
||||
)
|
||||
for lora in self.lora_requests.values()
|
||||
]
|
||||
model_cards.extend(lora_cards)
|
||||
return ModelList(data=model_cards)
|
||||
model_list.data.extend(lora_cards)
|
||||
return model_list
|
||||
|
||||
async def load_lora_adapter(
|
||||
self, request: LoadLoRAAdapterRequest, base_model_name: str | None = None
|
||||
|
||||
@@ -16,10 +16,8 @@ from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionReque
|
||||
from vllm.entrypoints.openai.completion.protocol import CompletionRequest
|
||||
from vllm.entrypoints.openai.engine.protocol import (
|
||||
ErrorResponse,
|
||||
ModelCard,
|
||||
ModelList,
|
||||
ModelPermission,
|
||||
)
|
||||
from vllm.entrypoints.openai.models.serving import OpenAIModelRegistry
|
||||
from vllm.entrypoints.openai.parser.harmony_utils import (
|
||||
get_developer_message,
|
||||
get_system_message,
|
||||
@@ -46,7 +44,7 @@ class OpenAIServingRender:
|
||||
model_config: ModelConfig,
|
||||
renderer: BaseRenderer,
|
||||
io_processor: Any,
|
||||
served_model_names: list[str],
|
||||
model_registry: OpenAIModelRegistry,
|
||||
*,
|
||||
request_logger: RequestLogger | None,
|
||||
chat_template: str | None,
|
||||
@@ -61,7 +59,7 @@ class OpenAIServingRender:
|
||||
self.model_config = model_config
|
||||
self.renderer = renderer
|
||||
self.io_processor = io_processor
|
||||
self.served_model_names = served_model_names
|
||||
self.model_registry = model_registry
|
||||
self.request_logger = request_logger
|
||||
self.chat_template = chat_template
|
||||
self.chat_template_content_format: ChatTemplateContentFormatOption = (
|
||||
@@ -252,21 +250,6 @@ class OpenAIServingRender:
|
||||
|
||||
return messages, [engine_prompt]
|
||||
|
||||
async def show_available_models(self) -> ModelList:
|
||||
"""Returns the models served by this render server."""
|
||||
max_model_len = self.model_config.max_model_len
|
||||
return ModelList(
|
||||
data=[
|
||||
ModelCard(
|
||||
id=name,
|
||||
max_model_len=max_model_len,
|
||||
root=self.model_config.model,
|
||||
permission=[ModelPermission()],
|
||||
)
|
||||
for name in self.served_model_names
|
||||
]
|
||||
)
|
||||
|
||||
def create_error_response(
|
||||
self,
|
||||
message: str | Exception,
|
||||
@@ -276,23 +259,11 @@ class OpenAIServingRender:
|
||||
) -> ErrorResponse:
|
||||
return create_error_response(message, err_type, status_code, param)
|
||||
|
||||
def _is_model_supported(self, model_name: str) -> bool:
|
||||
"""Simplified from OpenAIServing._is_model_supported (no LoRA support)."""
|
||||
return model_name in self.served_model_names
|
||||
|
||||
async def _check_model(
|
||||
self,
|
||||
request: Any,
|
||||
) -> ErrorResponse | None:
|
||||
"""Simplified from OpenAIServing._check_model (no LoRA support)."""
|
||||
if self._is_model_supported(request.model):
|
||||
return None
|
||||
return self.create_error_response(
|
||||
message=f"The model `{request.model}` does not exist.",
|
||||
err_type="NotFoundError",
|
||||
status_code=HTTPStatus.NOT_FOUND,
|
||||
param="model",
|
||||
)
|
||||
return await self.model_registry.check_model(request.model)
|
||||
|
||||
def _validate_chat_template(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user