[Frontend] Split OpenAIServingModels into OpenAIModelRegistry + OpenAIServingModels (#36536)

Signed-off-by: Sage Ahrac <sagiahrak@gmail.com>
This commit is contained in:
Sage
2026-03-12 12:29:37 +02:00
committed by GitHub
parent 5a71cdd76e
commit 06e0bc21d2
4 changed files with 73 additions and 62 deletions

View File

@@ -414,11 +414,19 @@ async def init_render_app_state(
directly from the :class:`~vllm.config.VllmConfig`.
"""
from vllm.entrypoints.chat_utils import load_chat_template
from vllm.entrypoints.openai.models.serving import OpenAIModelRegistry
from vllm.entrypoints.serve.render.serving import OpenAIServingRender
from vllm.plugins.io_processors import get_io_processor
from vllm.renderers import renderer_from_config
served_model_names = args.served_model_name or [args.model]
model_registry = OpenAIModelRegistry(
model_config=vllm_config.model_config,
base_model_paths=[
BaseModelPath(name=name, model_path=args.model)
for name in served_model_names
],
)
if args.enable_log_requests:
request_logger = RequestLogger(max_log_len=args.max_log_len)
@@ -435,7 +443,7 @@ async def init_render_app_state(
model_config=vllm_config.model_config,
renderer=renderer,
io_processor=io_processor,
served_model_names=served_model_names,
model_registry=model_registry,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,
@@ -447,8 +455,7 @@ async def init_render_app_state(
log_error_stack=args.log_error_stack,
)
# Expose models endpoint via the render handler.
state.openai_serving_models = state.openai_serving_render
state.openai_serving_models = model_registry
state.vllm_config = vllm_config
# Disable stats logging — there is no engine to poll.

View File

@@ -169,9 +169,7 @@ async def init_generate_state(
model_config=engine_client.model_config,
renderer=engine_client.renderer,
io_processor=engine_client.io_processor,
served_model_names=[
mp.name for mp in state.openai_serving_models.base_model_paths
],
model_registry=state.openai_serving_models.registry,
request_logger=request_logger,
chat_template=resolved_chat_template,
chat_template_content_format=args.chat_template_content_format,

View File

@@ -5,6 +5,7 @@ from asyncio import Lock
from collections import defaultdict
from http import HTTPStatus
from vllm.config import ModelConfig
from vllm.engine.protocol import EngineClient
from vllm.entrypoints.openai.engine.protocol import (
ErrorResponse,
@@ -27,6 +28,51 @@ from vllm.utils.counter import AtomicCounter
logger = init_logger(__name__)
class OpenAIModelRegistry:
"""Read-only view of the loaded base models with no engine dependency.
Suitable for CPU-only / render-only contexts that have no engine client
and no LoRA support.
"""
def __init__(
self,
model_config: ModelConfig,
base_model_paths: list[BaseModelPath],
) -> None:
self.model_config = model_config
self.base_model_paths = base_model_paths
def is_base_model(self, model_name: str) -> bool:
return any(model.name == model_name for model in self.base_model_paths)
async def check_model(self, model_name: str | None) -> ErrorResponse | None:
"""Return an ErrorResponse if model_name is not served, else None."""
if not model_name or self.is_base_model(model_name):
return None
return create_error_response(
message=f"The model `{model_name}` does not exist.",
err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND,
param="model",
)
async def show_available_models(self) -> ModelList:
"""Show available models (base models only)."""
max_model_len = self.model_config.max_model_len
return ModelList(
data=[
ModelCard(
id=base_model.name,
max_model_len=max_model_len,
root=base_model.model_path,
permission=[ModelPermission()],
)
for base_model in self.base_model_paths
]
)
class OpenAIServingModels:
"""Shared instance to hold data about the loaded base model(s) and adapters.
@@ -45,6 +91,11 @@ class OpenAIServingModels:
):
super().__init__()
self.registry = OpenAIModelRegistry(
model_config=engine_client.model_config,
base_model_paths=base_model_paths,
)
self.engine_client = engine_client
self.base_model_paths = base_model_paths
@@ -79,34 +130,18 @@ class OpenAIServingModels:
if isinstance(load_result, ErrorResponse):
raise ValueError(load_result.error.message)
def is_base_model(self, model_name) -> bool:
return any(model.name == model_name for model in self.base_model_paths)
def is_base_model(self, model_name: str) -> bool:
return self.registry.is_base_model(model_name)
def model_name(self, lora_request: LoRARequest | None = None) -> str:
"""Returns the appropriate model name depending on the availability
and support of the LoRA or base model.
Parameters:
- lora: LoRARequest that contain a base_model_name.
Returns:
- str: The name of the base model or the first available model path.
"""
if lora_request is not None:
return lora_request.lora_name
return self.base_model_paths[0].name
async def show_available_models(self) -> ModelList:
"""Show available models. This includes the base model and all adapters."""
max_model_len = self.model_config.max_model_len
model_cards = [
ModelCard(
id=base_model.name,
max_model_len=max_model_len,
root=base_model.model_path,
permission=[ModelPermission()],
)
for base_model in self.base_model_paths
]
"""Show available models. This includes the base model and all
adapters."""
model_list = await self.registry.show_available_models()
lora_cards = [
ModelCard(
id=lora.lora_name,
@@ -118,8 +153,8 @@ class OpenAIServingModels:
)
for lora in self.lora_requests.values()
]
model_cards.extend(lora_cards)
return ModelList(data=model_cards)
model_list.data.extend(lora_cards)
return model_list
async def load_lora_adapter(
self, request: LoadLoRAAdapterRequest, base_model_name: str | None = None

View File

@@ -16,10 +16,8 @@ from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionReque
from vllm.entrypoints.openai.completion.protocol import CompletionRequest
from vllm.entrypoints.openai.engine.protocol import (
ErrorResponse,
ModelCard,
ModelList,
ModelPermission,
)
from vllm.entrypoints.openai.models.serving import OpenAIModelRegistry
from vllm.entrypoints.openai.parser.harmony_utils import (
get_developer_message,
get_system_message,
@@ -46,7 +44,7 @@ class OpenAIServingRender:
model_config: ModelConfig,
renderer: BaseRenderer,
io_processor: Any,
served_model_names: list[str],
model_registry: OpenAIModelRegistry,
*,
request_logger: RequestLogger | None,
chat_template: str | None,
@@ -61,7 +59,7 @@ class OpenAIServingRender:
self.model_config = model_config
self.renderer = renderer
self.io_processor = io_processor
self.served_model_names = served_model_names
self.model_registry = model_registry
self.request_logger = request_logger
self.chat_template = chat_template
self.chat_template_content_format: ChatTemplateContentFormatOption = (
@@ -252,21 +250,6 @@ class OpenAIServingRender:
return messages, [engine_prompt]
async def show_available_models(self) -> ModelList:
"""Returns the models served by this render server."""
max_model_len = self.model_config.max_model_len
return ModelList(
data=[
ModelCard(
id=name,
max_model_len=max_model_len,
root=self.model_config.model,
permission=[ModelPermission()],
)
for name in self.served_model_names
]
)
def create_error_response(
self,
message: str | Exception,
@@ -276,23 +259,11 @@ class OpenAIServingRender:
) -> ErrorResponse:
return create_error_response(message, err_type, status_code, param)
def _is_model_supported(self, model_name: str) -> bool:
"""Simplified from OpenAIServing._is_model_supported (no LoRA support)."""
return model_name in self.served_model_names
async def _check_model(
self,
request: Any,
) -> ErrorResponse | None:
"""Simplified from OpenAIServing._check_model (no LoRA support)."""
if self._is_model_supported(request.model):
return None
return self.create_error_response(
message=f"The model `{request.model}` does not exist.",
err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND,
param="model",
)
return await self.model_registry.check_model(request.model)
def _validate_chat_template(
self,