[DisaggEverything] Tokens in<>out /generate endpoint (#24261)
Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
@@ -65,6 +65,8 @@ from vllm.entrypoints.openai.protocol import (
     EmbeddingResponse,
     ErrorInfo,
     ErrorResponse,
+    GenerateRequest,
+    GenerateResponse,
     IOProcessorResponse,
     PoolingBytesResponse,
     PoolingRequest,
@@ -96,6 +98,7 @@ from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
 from vllm.entrypoints.openai.serving_responses import OpenAIServingResponses
 from vllm.entrypoints.openai.serving_score import ServingScores
 from vllm.entrypoints.openai.serving_tokenization import OpenAIServingTokenization
+from vllm.entrypoints.openai.serving_tokens import ServingTokens
 from vllm.entrypoints.openai.serving_transcription import (
     OpenAIServingTranscription,
     OpenAIServingTranslation,
@@ -357,6 +360,10 @@ def engine_client(request: Request) -> EngineClient:
     return request.app.state.engine_client
 
 
+def generate_tokens(request: Request) -> ServingTokens | None:
+    return request.app.state.serving_tokens
+
+
 @router.get("/health", response_class=Response)
 async def health(raw_request: Request) -> Response:
     """Health check."""
@@ -1228,6 +1235,41 @@ INVOCATION_VALIDATORS = [
 ]
 
 
+@router.post(
+    "/inference/v1/generate",
+    dependencies=[Depends(validate_json_request)],
+    responses={
+        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
+        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
+        HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
+        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
+    },
+)
+@with_cancellation
+@load_aware_call
+async def generate(request: GenerateRequest, raw_request: Request):
+    handler = generate_tokens(raw_request)
+    if handler is None:
+        return base(raw_request).create_error_response(
+            message="The model does not support generate tokens API"
+        )
+    try:
+        generator = await handler.serve_tokens(request, raw_request)
+    except Exception as e:
+        raise HTTPException(
+            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
+        ) from e
+    if isinstance(generator, ErrorResponse):
+        return JSONResponse(
+            content=generator.model_dump(), status_code=generator.error.code
+        )
+
+    elif isinstance(generator, GenerateResponse):
+        return JSONResponse(content=generator.model_dump())
+
+    return StreamingResponse(content=generator, media_type="text/event-stream")
+
+
 if envs.VLLM_TORCH_PROFILER_DIR:
     logger.warning_once(
         "Torch Profiler is enabled in the API server. This should ONLY be "
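For orientation, here is a minimal client sketch against the new route. The GenerateRequest schema lives in vllm/entrypoints/openai/protocol.py and is not shown in this hunk, so the payload field names below are illustrative assumptions rather than the confirmed wire format:

    # Hypothetical client for POST /inference/v1/generate; assumes a server on
    # localhost:8000. Payload field names are assumptions (GenerateRequest is
    # defined in protocol.py, outside this diff).
    import requests

    payload = {
        "model": "my-model",                    # assumed field: served model name
        "token_ids": [9906, 11, 1917],          # assumed field: prompt as token IDs ("tokens in")
        "sampling_params": {"max_tokens": 32},  # assumed field name
    }
    resp = requests.post(
        "http://localhost:8000/inference/v1/generate", json=payload, stream=True
    )
    if resp.headers.get("content-type", "").startswith("text/event-stream"):
        for line in resp.iter_lines():  # streaming path returns server-sent events
            print(line.decode())
    else:
        print(resp.status_code, resp.json())  # GenerateResponse or ErrorResponse body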
@@ -1629,6 +1671,31 @@ def build_app(args: Namespace) -> FastAPI:
         )
 
         app = sagemaker_standards.bootstrap(app)
+    # Optional endpoints
+    if args.tokens_only:
+
+        @app.post("/abort_requests")
+        async def abort_requests(raw_request: Request):
+            """
+            Abort one or more requests. To be used in a
+            Disaggregated Everything setup.
+            """
+            try:
+                body = await raw_request.json()
+            except json.JSONDecodeError as e:
+                raise HTTPException(
+                    status_code=HTTPStatus.BAD_REQUEST.value,
+                    detail=f"JSON decode error: {e}",
+                ) from e
+            request_ids = body.get("request_ids")
+            if request_ids is None:
+                raise HTTPException(
+                    status_code=HTTPStatus.BAD_REQUEST.value,
+                    detail="Missing 'request_ids' in request body",
+                )
+            # Abort requests in background
+            asyncio.create_task(engine_client(raw_request).abort(request_ids))
+            return Response(status_code=200)
 
     return app
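The abort contract is fully visible in the hunk above: a JSON body carrying a "request_ids" list, answered with an empty 200 once the abort has been scheduled. A small sketch of the caller side (server address, and the --tokens-only flag implied by args.tokens_only, are assumptions):

    # Sketch of calling the new /abort_requests endpoint; assumes the server
    # was started in tokens-only mode and listens on localhost:8000.
    import requests

    resp = requests.post(
        "http://localhost:8000/abort_requests",
        json={"request_ids": ["req-123", "req-456"]},  # IDs of in-flight requests
    )
    # The abort is dispatched via asyncio.create_task, so a 200 means
    # "accepted", not "completed".
    assert resp.status_code == 200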
@@ -1851,6 +1918,20 @@ async def init_app_state(
         if "generate" in supported_tasks
         else None
     )
+    state.serving_tokens = (
+        ServingTokens(
+            engine_client,
+            state.openai_serving_models,
+            request_logger=request_logger,
+            return_tokens_as_token_ids=args.return_tokens_as_token_ids,
+            log_error_stack=args.log_error_stack,
+            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
+            enable_log_outputs=args.enable_log_outputs,
+            force_no_detokenize=args.tokens_only,
+        )
+        if "generate" in supported_tasks
+        else None
+    )
 
     state.enable_server_load_tracking = args.enable_server_load_tracking
     state.server_load_metrics = 0
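Since force_no_detokenize=args.tokens_only skips server-side detokenization in tokens-only mode, the consumer is responsible for turning output token IDs back into text (the "tokens out" half of the PR title). A sketch of that client-side step, with the response field holding the IDs left as an assumption:

    # Client-side detokenization sketch for a tokens-only deployment; the exact
    # GenerateResponse field carrying the output IDs is an assumption here.
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("my-model")  # same model the server serves
    output_token_ids = [9906, 11, 1917, 0]  # placeholder: IDs taken from the response
    print(tokenizer.decode(output_token_ids, skip_special_tokens=True))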