diff --git a/vllm/entrypoints/pooling/embed/api_router.py b/vllm/entrypoints/pooling/embed/api_router.py
index 50a401885..c252bb43c 100644
--- a/vllm/entrypoints/pooling/embed/api_router.py
+++ b/vllm/entrypoints/pooling/embed/api_router.py
@@ -1,5 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import importlib.util
+from functools import lru_cache
 from http import HTTPStatus
 
 from fastapi import APIRouter, Depends, Request
@@ -15,9 +17,30 @@ from vllm.entrypoints.pooling.embed.protocol import (
 )
 from vllm.entrypoints.pooling.embed.serving import OpenAIServingEmbedding
 from vllm.entrypoints.utils import load_aware_call, with_cancellation
+from vllm.logger import init_logger
 
 router = APIRouter()
 
+logger = init_logger(__name__)
+
+
+@lru_cache(maxsize=1)
+def _get_json_response_cls() -> type[JSONResponse]:
+    """Pick the response class for JSON embedding responses.
+
+    Returns ``ORJSONResponse`` when the ``orjson`` package can be found,
+    otherwise logs an install hint and falls back to ``JSONResponse``.
+    The result is cached, so the lookup and the hint happen only once.
+    """
+    if importlib.util.find_spec("orjson") is not None:
+        from fastapi.responses import ORJSONResponse
+
+        return ORJSONResponse
+    logger.warning_once(
+        "To make v1/embeddings API fast, please install orjson by `pip install orjson`"
+    )
+    return JSONResponse
+
 
 def embedding(request: Request) -> OpenAIServingEmbedding | None:
     return request.app.state.openai_serving_embedding
@@ -54,7 +77,7 @@ async def create_embedding(
             content=generator.model_dump(), status_code=generator.error.code
         )
     elif isinstance(generator, EmbeddingResponse):
-        return JSONResponse(content=generator.model_dump())
+        return _get_json_response_cls()(content=generator.model_dump())
     elif isinstance(generator, EmbeddingBytesResponse):
         return StreamingResponse(
             content=generator.content,