533 lines
17 KiB
Python
Executable File
533 lines
17 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
# mypy: ignore-errors
|
|
"""
|
|
vLLM gRPC Server
|
|
|
|
Starts a gRPC server for vLLM using the VllmEngine protocol.
|
|
|
|
Usage:
|
|
python -m vllm.entrypoints.grpc_server --model <model_path>
|
|
|
|
Example:
|
|
python -m vllm.entrypoints.grpc_server \
|
|
--model meta-llama/Llama-2-7b-hf \
|
|
--host 0.0.0.0 \
|
|
--port 50051
|
|
"""
|
|
|
|
import argparse
|
|
import asyncio
|
|
import signal
|
|
import sys
|
|
import time
|
|
from collections.abc import AsyncGenerator
|
|
|
|
import grpc
|
|
import uvloop
|
|
from grpc_reflection.v1alpha import reflection
|
|
|
|
from vllm import SamplingParams, TextPrompt, TokensPrompt
|
|
from vllm.engine.arg_utils import AsyncEngineArgs
|
|
from vllm.entrypoints.utils import log_version_and_model
|
|
from vllm.grpc import vllm_engine_pb2, vllm_engine_pb2_grpc
|
|
from vllm.logger import init_logger
|
|
from vllm.outputs import RequestOutput
|
|
from vllm.sampling_params import RequestOutputKind, StructuredOutputsParams
|
|
from vllm.usage.usage_lib import UsageContext
|
|
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
|
from vllm.v1.engine.async_llm import AsyncLLM
|
|
from vllm.version import __version__ as VLLM_VERSION
|
|
|
|
logger = init_logger(__name__)
|
|
|
|
|
|
class VllmEngineServicer(vllm_engine_pb2_grpc.VllmEngineServicer):
    """
    gRPC servicer implementing the VllmEngine service.

    Handles 6 RPCs:
    - Generate: Streaming text generation
    - Embed: Embeddings (TODO)
    - HealthCheck: Health probe
    - Abort: Cancel requests out-of-band
    - GetModelInfo: Model metadata
    - GetServerInfo: Server state
    """

    def __init__(self, async_llm: AsyncLLM, start_time: float):
        """
        Initialize the servicer.

        Args:
            async_llm: The AsyncLLM instance
            start_time: The server start time, in seconds since epoch
        """
        # The engine is owned by the caller (serve_grpc shuts it down);
        # this class only borrows it to serve requests.
        self.async_llm = async_llm
        self.start_time = start_time
        logger.info("VllmEngineServicer initialized")

    async def Generate(
        self,
        request: vllm_engine_pb2.GenerateRequest,
        context: grpc.aio.ServicerContext,
    ) -> AsyncGenerator[vllm_engine_pb2.GenerateResponse, None]:
        """
        Handle streaming generation requests.

        Args:
            request: The GenerateRequest protobuf
            context: gRPC context

        Yields:
            GenerateResponse protobuf messages (streaming). When
            request.stream is set, every engine step yields a `chunk`
            response; the final step additionally yields a `complete`
            response. When not streaming, only the `complete` response
            is yielded.
        """
        request_id = request.request_id
        logger.debug("Generate request %s received.", request_id)

        try:
            # The request carries either pre-tokenized ids or raw text
            # (a protobuf oneof named "input").
            if request.WhichOneof("input") == "tokenized":
                prompt: TokensPrompt = {
                    "prompt_token_ids": list(request.tokenized.input_ids)
                }
                # Preserve the original text when the client supplied it
                # (useful for logging / stop-string handling downstream).
                if request.tokenized.original_text:
                    prompt["prompt"] = request.tokenized.original_text
            else:
                prompt: TextPrompt = {"prompt": request.text}

            # Translate proto sampling params; `stream` selects DELTA vs
            # FINAL_ONLY output_kind (see _sampling_params_from_proto).
            sampling_params = self._sampling_params_from_proto(
                request.sampling_params, stream=request.stream
            )

            async for output in self.async_llm.generate(
                prompt=prompt,
                sampling_params=sampling_params,
                request_id=request_id,
            ):
                # For streaming requests, forward every delta as a chunk.
                if request.stream:
                    yield self._chunk_response(output)

                # The final engine output (finished=True) always produces
                # the terminal `complete` message, streaming or not.
                if output.finished:
                    yield self._complete_response(output)

        except ValueError as e:
            # Invalid request error (equiv to HTTP 400).
            await context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
        except Exception as e:
            logger.exception("Error in Generate for request %s", request_id)
            await context.abort(grpc.StatusCode.INTERNAL, str(e))

    async def Embed(
        self,
        request: vllm_engine_pb2.EmbedRequest,
        context: grpc.aio.ServicerContext,
    ) -> vllm_engine_pb2.EmbedResponse:
        """
        Handle embedding requests.

        TODO: Implement in Phase 4. Currently always aborts the RPC with
        UNIMPLEMENTED.

        Args:
            request: The EmbedRequest protobuf
            context: gRPC context

        Returns:
            EmbedResponse protobuf (never returned today; the abort raises)
        """
        logger.warning("Embed RPC not yet implemented")
        await context.abort(
            grpc.StatusCode.UNIMPLEMENTED, "Embed RPC not yet implemented"
        )

    async def HealthCheck(
        self,
        request: vllm_engine_pb2.HealthCheckRequest,
        context: grpc.aio.ServicerContext,
    ) -> vllm_engine_pb2.HealthCheckResponse:
        """
        Handle health check requests.

        Health is derived solely from AsyncLLM.errored — the engine is
        considered healthy as long as it has not entered an error state.

        Args:
            request: The HealthCheckRequest protobuf
            context: gRPC context

        Returns:
            HealthCheckResponse protobuf
        """
        is_healthy = not self.async_llm.errored
        # NOTE(review): "Health" reads like a typo for "Healthy", but it is
        # a runtime string clients may match on — confirm before changing.
        message = "Health" if is_healthy else "Engine is not alive"

        logger.debug("HealthCheck request: healthy=%s, message=%s", is_healthy, message)

        return vllm_engine_pb2.HealthCheckResponse(healthy=is_healthy, message=message)

    async def Abort(
        self,
        request: vllm_engine_pb2.AbortRequest,
        context: grpc.aio.ServicerContext,
    ) -> vllm_engine_pb2.AbortResponse:
        """
        Out-of-band abort requests.

        Args:
            request: The AbortRequest protobuf
            context: gRPC context

        Returns:
            AbortResponse protobuf (empty on success)
        """
        # request_ids is a protobuf repeated field (an iterable of strings).
        request_ids = request.request_ids
        logger.debug("Abort requests: %s", request_ids)

        # NOTE(review): assumes AsyncLLM.abort accepts an iterable of
        # request ids — verify against the installed vLLM version's API.
        await self.async_llm.abort(request_ids)
        return vllm_engine_pb2.AbortResponse()

    async def GetModelInfo(
        self,
        request: vllm_engine_pb2.GetModelInfoRequest,
        context: grpc.aio.ServicerContext,
    ) -> vllm_engine_pb2.GetModelInfoResponse:
        """
        Handle model info requests.

        All fields are read straight from the engine's model_config.

        Args:
            request: The GetModelInfoRequest protobuf
            context: gRPC context

        Returns:
            GetModelInfoResponse protobuf
        """
        model_config = self.async_llm.model_config

        return vllm_engine_pb2.GetModelInfoResponse(
            model_path=model_config.model,
            is_generation=model_config.runner_type == "generate",
            max_context_length=model_config.max_model_len,
            vocab_size=model_config.get_vocab_size(),
            supports_vision=model_config.is_multimodal_model,
        )

    async def GetServerInfo(
        self,
        request: vllm_engine_pb2.GetServerInfoRequest,
        context: grpc.aio.ServicerContext,
    ) -> vllm_engine_pb2.GetServerInfoResponse:
        """
        Handle server info requests.

        Args:
            request: The GetServerInfoRequest protobuf
            context: gRPC context

        Returns:
            GetServerInfoResponse protobuf
        """
        num_requests = self.async_llm.output_processor.get_num_unfinished_requests()

        return vllm_engine_pb2.GetServerInfoResponse(
            active_requests=num_requests,
            is_paused=False,  # TODO
            # NOTE(review): this reports "now", not the time of the last
            # received request — the TODO below flags the same concern.
            last_receive_timestamp=time.time(),  # TODO looks wrong?
            uptime_seconds=time.time() - self.start_time,
            server_type="vllm-grpc",
        )

    # ========== Helper methods ==========

    @staticmethod
    def _sampling_params_from_proto(
        params: vllm_engine_pb2.SamplingParams, stream: bool = True
    ) -> SamplingParams:
        """
        Convert protobuf SamplingParams to vLLM SamplingParams.

        Proto3 scalar fields cannot distinguish "unset" from 0/0.0, so
        several fields use sentinel handling: top_p and repetition_penalty
        treat 0.0 as "use the vLLM default (1.0)", while optional fields
        (temperature, max_tokens, logprobs, ...) use HasField.

        Args:
            params: Protobuf SamplingParams message
            stream: Whether streaming is enabled

        Returns:
            vLLM SamplingParams with structured_outputs applied;
            detokenize is enabled only when stop strings are present
        """
        # Empty repeated fields become None so vLLM applies its defaults.
        stop = list(params.stop) if params.stop else None
        stop_token_ids = list(params.stop_token_ids) if params.stop_token_ids else None

        # Structured-output constraints arrive as a protobuf oneof; at
        # most one of the branches below can be set.
        structured_outputs = None
        constraint_field = params.WhichOneof("constraint")
        if constraint_field:
            if constraint_field == "json_schema":
                structured_outputs = StructuredOutputsParams(json=params.json_schema)
            elif constraint_field == "regex":
                structured_outputs = StructuredOutputsParams(regex=params.regex)
            elif constraint_field == "grammar":
                structured_outputs = StructuredOutputsParams(grammar=params.grammar)
            elif constraint_field == "structural_tag":
                structured_outputs = StructuredOutputsParams(
                    structural_tag=params.structural_tag
                )
            elif constraint_field == "json_object":
                structured_outputs = StructuredOutputsParams(
                    json_object=params.json_object
                )
            elif constraint_field == "choice":
                structured_outputs = StructuredOutputsParams(
                    choice=list(params.choice.choices)
                )

        # Create SamplingParams.
        # output_kind=DELTA: return only new tokens in each chunk (for
        # streaming); FINAL_ONLY: a single output when the request finishes.
        return SamplingParams(
            temperature=params.temperature if params.HasField("temperature") else 1.0,
            top_p=params.top_p if params.top_p != 0.0 else 1.0,
            top_k=params.top_k,
            min_p=params.min_p,
            frequency_penalty=params.frequency_penalty,
            presence_penalty=params.presence_penalty,
            repetition_penalty=params.repetition_penalty
            if params.repetition_penalty != 0.0
            else 1.0,
            max_tokens=params.max_tokens if params.HasField("max_tokens") else None,
            min_tokens=params.min_tokens,
            stop=stop,
            stop_token_ids=stop_token_ids,
            skip_special_tokens=params.skip_special_tokens,
            spaces_between_special_tokens=params.spaces_between_special_tokens,
            ignore_eos=params.ignore_eos,
            n=params.n if params.n > 0 else 1,
            logprobs=params.logprobs if params.HasField("logprobs") else None,
            prompt_logprobs=params.prompt_logprobs
            if params.HasField("prompt_logprobs")
            else None,
            seed=params.seed if params.HasField("seed") else None,
            include_stop_str_in_output=params.include_stop_str_in_output,
            logit_bias=dict(params.logit_bias) if params.logit_bias else None,
            truncate_prompt_tokens=params.truncate_prompt_tokens
            if params.HasField("truncate_prompt_tokens")
            else None,
            structured_outputs=structured_outputs,
            # detokenize must be True if stop strings are used
            detokenize=bool(stop),
            output_kind=RequestOutputKind.DELTA
            if stream
            else RequestOutputKind.FINAL_ONLY,
        )

    @staticmethod
    def _chunk_response(output: RequestOutput) -> vllm_engine_pb2.GenerateResponse:
        """
        Build a streaming chunk response from vLLM output.

        When output_kind=DELTA, vLLM returns only new tokens automatically.

        Args:
            output: vLLM RequestOutput (with delta tokens when output_kind=DELTA)

        Returns:
            GenerateResponse with chunk field set
        """
        # Only the first completion is forwarded; when n > 1 the extra
        # completions are currently dropped from the stream.
        completion = output.outputs[0] if output.outputs else None

        if completion is None:
            # No completion yet — emit an empty chunk rather than nothing.
            return vllm_engine_pb2.GenerateResponse(
                chunk=vllm_engine_pb2.GenerateStreamChunk(
                    token_ids=[],
                    prompt_tokens=0,
                    completion_tokens=0,
                    cached_tokens=0,
                ),
            )

        # When output_kind=DELTA, completion.token_ids contains only new
        # tokens — vLLM handles the delta logic internally, so
        # completion_tokens here is the delta count (client accumulates).
        return vllm_engine_pb2.GenerateResponse(
            chunk=vllm_engine_pb2.GenerateStreamChunk(
                token_ids=completion.token_ids,
                prompt_tokens=len(output.prompt_token_ids)
                if output.prompt_token_ids
                else 0,
                completion_tokens=len(completion.token_ids),  # Delta count
                cached_tokens=output.num_cached_tokens,
            ),
        )

    @staticmethod
    def _complete_response(output: RequestOutput) -> vllm_engine_pb2.GenerateResponse:
        """
        Build a final completion response from vLLM output.

        Args:
            output: vLLM RequestOutput (finished=True)

        Returns:
            GenerateResponse with complete field set
        """
        # Only the first completion is reported; when n > 1 the extra
        # completions are currently dropped.
        completion = output.outputs[0] if output.outputs else None

        if completion is None:
            # Finished with no completion output: report an error finish.
            return vllm_engine_pb2.GenerateResponse(
                complete=vllm_engine_pb2.GenerateComplete(
                    output_ids=[],
                    finish_reason="error",
                    prompt_tokens=0,
                    completion_tokens=0,
                    cached_tokens=0,
                ),
            )

        # When streaming (DELTA mode): completion.token_ids will be the
        # empty/last delta, and the client accumulates token counts.
        # When non-streaming (FINAL_ONLY mode): completion.token_ids has
        # all generated tokens.
        return vllm_engine_pb2.GenerateResponse(
            complete=vllm_engine_pb2.GenerateComplete(
                output_ids=completion.token_ids,
                finish_reason=completion.finish_reason or "stop",
                prompt_tokens=len(output.prompt_token_ids)
                if output.prompt_token_ids
                else 0,
                completion_tokens=len(completion.token_ids),
                cached_tokens=output.num_cached_tokens,
            ),
        )
|
|
|
|
|
|
async def serve_grpc(args: argparse.Namespace):
    """
    Build the engine and the gRPC server, then serve until shutdown.

    Args:
        args: Parsed command line arguments
    """
    log_version_and_model(logger, VLLM_VERSION, args.model)
    logger.info("vLLM gRPC server args: %s", args)

    launched_at = time.time()

    # Turn CLI args into a full vLLM engine configuration.
    engine_args = AsyncEngineArgs.from_cli_args(args)
    vllm_config = engine_args.create_engine_config(
        usage_context=UsageContext.OPENAI_API_SERVER
    )

    # Spin up the async engine that backs every RPC.
    async_llm = AsyncLLM.from_vllm_config(
        vllm_config=vllm_config,
        usage_context=UsageContext.OPENAI_API_SERVER,
        enable_log_requests=args.enable_log_requests,
        disable_log_stats=args.disable_log_stats_server,
    )

    # Unlimited message sizes: token payloads can be large.
    grpc_options = [
        ("grpc.max_send_message_length", -1),
        ("grpc.max_receive_message_length", -1),
    ]
    server = grpc.aio.server(options=grpc_options)

    # Wire the servicer into the server.
    servicer = VllmEngineServicer(async_llm, launched_at)
    vllm_engine_pb2_grpc.add_VllmEngineServicer_to_server(servicer, server)

    # Reflection lets grpcurl and similar tools discover the service.
    reflection.enable_server_reflection(
        (
            vllm_engine_pb2.DESCRIPTOR.services_by_name["VllmEngine"].full_name,
            reflection.SERVICE_NAME,
        ),
        server,
    )

    # Bind and start.
    address = f"{args.host}:{args.port}"
    server.add_insecure_port(address)
    await server.start()
    logger.info("vLLM gRPC server started on %s", address)
    logger.info("Server is ready to accept requests")

    # Park until SIGTERM/SIGINT flips the shutdown event.
    shutdown_event = asyncio.Event()
    loop = asyncio.get_running_loop()

    def _on_signal():
        logger.info("Received shutdown signal")
        shutdown_event.set()

    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, _on_signal)

    try:
        await shutdown_event.wait()
    except KeyboardInterrupt:
        logger.info("Interrupted by user")
    finally:
        logger.info("Shutting down vLLM gRPC server...")

        # Drain in-flight RPCs, then stop the transport.
        await server.stop(grace=5.0)
        logger.info("gRPC server stopped")

        # Tear down the engine last.
        async_llm.shutdown()
        logger.info("AsyncLLM engine stopped")

        logger.info("Shutdown complete")
|
|
|
|
|
|
def main():
    """Main entry point: build the CLI parser and launch the server."""
    parser = FlexibleArgumentParser(description="vLLM gRPC Server")

    # gRPC-server-specific flags (engine flags are merged in below).
    parser.add_argument(
        "--host", type=str, default="0.0.0.0",
        help="Host to bind gRPC server to",
    )
    parser.add_argument(
        "--port", type=int, default=50051,
        help="Port to bind gRPC server to",
    )
    parser.add_argument(
        "--disable-log-stats-server", action="store_true",
        help="Disable stats logging on server side",
    )

    # Append the standard vLLM engine arguments.
    parser = AsyncEngineArgs.add_cli_args(parser)
    cli_args = parser.parse_args()

    # Run the async server on uvloop; any fatal error exits non-zero.
    try:
        uvloop.run(serve_grpc(cli_args))
    except Exception as e:
        logger.exception("Server failed: %s", e)
        sys.exit(1)
|
|
|
|
|
|
# Script entry point (also reachable via `python -m vllm.entrypoints.grpc_server`).
if __name__ == "__main__":
    main()
|