[DisaggEverything] Tokens in<>out /generate endpoint (#24261)

Signed-off-by: NickLucche <nlucches@redhat.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Author: Nicolò Lucchesi
Date: 2025-11-14 17:58:01 +01:00
Committed by: GitHub
Parent: d54a18a47e
Commit: 6f1e7f7226
12 changed files with 822 additions and 9 deletions

vllm/entrypoints/openai/api_server.py

@@ -65,6 +65,8 @@ from vllm.entrypoints.openai.protocol import (
    EmbeddingResponse,
    ErrorInfo,
    ErrorResponse,
    GenerateRequest,
    GenerateResponse,
    IOProcessorResponse,
    PoolingBytesResponse,
    PoolingRequest,
@@ -96,6 +98,7 @@ from vllm.entrypoints.openai.serving_pooling import OpenAIServingPooling
from vllm.entrypoints.openai.serving_responses import OpenAIServingResponses
from vllm.entrypoints.openai.serving_score import ServingScores
from vllm.entrypoints.openai.serving_tokenization import OpenAIServingTokenization
from vllm.entrypoints.openai.serving_tokens import ServingTokens
from vllm.entrypoints.openai.serving_transcription import (
    OpenAIServingTranscription,
    OpenAIServingTranslation,
@@ -357,6 +360,10 @@ def engine_client(request: Request) -> EngineClient:
    return request.app.state.engine_client


def generate_tokens(request: Request) -> ServingTokens | None:
    return request.app.state.serving_tokens


@router.get("/health", response_class=Response)
async def health(raw_request: Request) -> Response:
    """Health check."""
@@ -1228,6 +1235,41 @@ INVOCATION_VALIDATORS = [
]


@router.post(
    "/inference/v1/generate",
    dependencies=[Depends(validate_json_request)],
    responses={
        HTTPStatus.OK.value: {"content": {"text/event-stream": {}}},
        HTTPStatus.BAD_REQUEST.value: {"model": ErrorResponse},
        HTTPStatus.NOT_FOUND.value: {"model": ErrorResponse},
        HTTPStatus.INTERNAL_SERVER_ERROR.value: {"model": ErrorResponse},
    },
)
@with_cancellation
@load_aware_call
async def generate(request: GenerateRequest, raw_request: Request):
    handler = generate_tokens(raw_request)
    if handler is None:
        return base(raw_request).create_error_response(
            message="The model does not support generate tokens API"
        )

    try:
        generator = await handler.serve_tokens(request, raw_request)
    except Exception as e:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR.value, detail=str(e)
        ) from e

    if isinstance(generator, ErrorResponse):
        return JSONResponse(
            content=generator.model_dump(), status_code=generator.error.code
        )
    elif isinstance(generator, GenerateResponse):
        return JSONResponse(content=generator.model_dump())

    return StreamingResponse(content=generator, media_type="text/event-stream")


if envs.VLLM_TORCH_PROFILER_DIR:
    logger.warning_once(
        "Torch Profiler is enabled in the API server. This should ONLY be "
@@ -1629,6 +1671,31 @@ def build_app(args: Namespace) -> FastAPI:
        )
        app = sagemaker_standards.bootstrap(app)

    # Optional endpoints
    if args.tokens_only:

        @app.post("/abort_requests")
        async def abort_requests(raw_request: Request):
            """
            Abort one or more requests. To be used in a
            Disaggregated Everything setup.
            """
            try:
                body = await raw_request.json()
            except json.JSONDecodeError as e:
                raise HTTPException(
                    status_code=HTTPStatus.BAD_REQUEST.value,
                    detail=f"JSON decode error: {e}",
                ) from e
            request_ids = body.get("request_ids")
            if request_ids is None:
                raise HTTPException(
                    status_code=HTTPStatus.BAD_REQUEST.value,
                    detail="Missing 'request_ids' in request body",
                )
            # Abort requests in background
            asyncio.create_task(engine_client(raw_request).abort(request_ids))
            return Response(status_code=200)

    return app
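
A matching sketch for calling the abort endpoint; the {"request_ids": [...]} body shape is exactly what the handler above reads, while the ID values are placeholders.

# Abort two in-flight requests on a tokens-only server.
import requests

resp = requests.post(
    "http://localhost:8000/abort_requests",
    json={"request_ids": ["req-123", "req-456"]},
)
assert resp.status_code == 200  # empty 200; the abort runs as a background task
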
@@ -1851,6 +1918,20 @@ async def init_app_state(
        if "generate" in supported_tasks
        else None
    )
    state.serving_tokens = (
        ServingTokens(
            engine_client,
            state.openai_serving_models,
            request_logger=request_logger,
            return_tokens_as_token_ids=args.return_tokens_as_token_ids,
            log_error_stack=args.log_error_stack,
            enable_prompt_tokens_details=args.enable_prompt_tokens_details,
            enable_log_outputs=args.enable_log_outputs,
            force_no_detokenize=args.tokens_only,
        )
        if "generate" in supported_tasks
        else None
    )
    state.enable_server_load_tracking = args.enable_server_load_tracking
    state.server_load_metrics = 0
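
Because force_no_detokenize is tied to args.tokens_only, a tokens-only server presumably returns token IDs without accompanying text, leaving detokenization to the consumer. A sketch of that client-side step, in which the model name and the idea that the output IDs come from a GenerateResponse are illustrative assumptions:

# Detokenize server output locally instead of on the decode worker.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
output_ids = [9906, 11, 1917, 0]  # e.g. token IDs pulled from a GenerateResponse
print(tokenizer.decode(output_ids, skip_special_tokens=True))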