feat(grpc): extract gRPC servicer into smg-grpc-servicer package, add --grpc flag to vllm serve (#36169)

Signed-off-by: Chang Su <chang.s.su@oracle.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Chang Su
2026-03-10 03:29:59 -07:00
committed by GitHub
parent ddbb0d230a
commit 507ddbe992
13 changed files with 60 additions and 1245 deletions

View File

@@ -9,7 +9,6 @@ requires = [
"torch == 2.10.0",
"wheel",
"jinja2",
"grpcio-tools==1.78.0",
]
build-backend = "setuptools.build_meta"
@@ -57,10 +56,6 @@ include = ["vllm*"]
"vllm/third_party/**" = ["ALL"]
"vllm/version.py" = ["F401"]
"vllm/_version.py" = ["ALL"]
# Exclude generated protobuf files
"vllm/grpc/*_pb2.py" = ["ALL"]
"vllm/grpc/*_pb2_grpc.py" = ["ALL"]
"vllm/grpc/*_pb2.pyi" = ["ALL"]
[tool.ruff.lint]
select = [

View File

@@ -10,4 +10,3 @@ jinja2>=3.1.6
regex
build
protobuf >= 5.29.6, !=6.30.*, !=6.31.*, !=6.32.*, !=6.33.0.*, !=6.33.1.*, !=6.33.2.*, !=6.33.3.*, !=6.33.4.*
grpcio-tools==1.78.0 # Required for grpc entrypoints

View File

@@ -51,8 +51,6 @@ openai-harmony >= 0.0.3 # Required for gpt-oss
anthropic >= 0.71.0
model-hosting-container-standards >= 0.1.13, < 1.0.0
mcp
grpcio
grpcio-reflection
opentelemetry-sdk >= 1.27.0
opentelemetry-api >= 1.27.0
opentelemetry-exporter-otlp >= 1.27.0

View File

@@ -4,7 +4,6 @@
# The version of gRPC libraries should be consistent with each other
grpcio==1.78.0
grpcio-reflection==1.78.0
grpcio-tools==1.78.0
numba == 0.61.2 # Required for N-gram speculative decoding

View File

@@ -51,7 +51,6 @@ tritonclient>=2.51.0
# The version of gRPC libraries should be consistent with each other
grpcio==1.78.0
grpcio-reflection==1.78.0
grpcio-tools==1.78.0
arctic-inference == 0.1.1 # Required for suffix decoding test
numba == 0.61.2 # Required for N-gram speculative decoding

View File

@@ -289,13 +289,10 @@ grpcio==1.78.0
# via
# -r requirements/test.in
# grpcio-reflection
# grpcio-tools
# ray
# tensorboard
grpcio-reflection==1.78.0
# via -r requirements/test.in
grpcio-tools==1.78.0
# via -r requirements/test.in
h11==0.14.0
# via
# httpcore
@@ -765,7 +762,6 @@ protobuf==6.33.2
# google-api-core
# googleapis-common-protos
# grpcio-reflection
# grpcio-tools
# opentelemetry-proto
# proto-plus
# ray
@@ -1045,7 +1041,6 @@ sentry-sdk==2.52.0
# via wandb
setuptools==77.0.3
# via
# grpcio-tools
# lightning-utilities
# pytablewriter
# tensorboard

View File

@@ -18,8 +18,6 @@ import torch
from packaging.version import Version, parse
from setuptools import Extension, setup
from setuptools.command.build_ext import build_ext
from setuptools.command.build_py import build_py
from setuptools.command.develop import develop
from setuptools_scm import get_version
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME
@@ -81,81 +79,6 @@ def is_freethreaded():
return bool(sysconfig.get_config_var("Py_GIL_DISABLED"))
def compile_grpc_protos():
"""Compile gRPC protobuf definitions during build.
This generates *_pb2.py, *_pb2_grpc.py, and *_pb2.pyi files from
the vllm_engine.proto definition.
"""
try:
from grpc_tools import protoc
except ImportError:
logger.warning(
"grpcio-tools not installed, skipping gRPC proto compilation. "
"gRPC server functionality will not be available."
)
return False
proto_file = ROOT_DIR / "vllm" / "grpc" / "vllm_engine.proto"
if not proto_file.exists():
logger.warning("Proto file not found at %s, skipping compilation", proto_file)
return False
logger.info("Compiling gRPC protobuf: %s", proto_file)
result = protoc.main(
[
"grpc_tools.protoc",
f"--proto_path={ROOT_DIR}",
f"--python_out={ROOT_DIR}",
f"--grpc_python_out={ROOT_DIR}",
f"--pyi_out={ROOT_DIR}",
str(proto_file),
]
)
if result != 0:
logger.error("protoc failed with exit code %s", result)
return False
# Add SPDX headers and mypy ignore to generated files
spdx_header = (
"# SPDX-License-Identifier: Apache-2.0\n"
"# SPDX-FileCopyrightText: Copyright contributors to the vLLM project\n"
"# mypy: ignore-errors\n"
)
grpc_dir = ROOT_DIR / "vllm" / "grpc"
for generated_file in [
grpc_dir / "vllm_engine_pb2.py",
grpc_dir / "vllm_engine_pb2_grpc.py",
grpc_dir / "vllm_engine_pb2.pyi",
]:
if generated_file.exists():
content = generated_file.read_text()
if not content.startswith("# SPDX-License-Identifier"):
generated_file.write_text(spdx_header + content)
logger.info("gRPC protobuf compilation successful")
return True
class BuildPyAndGenerateGrpc(build_py):
"""Build Python modules and generate gRPC stubs from proto files."""
def run(self):
compile_grpc_protos()
super().run()
class DevelopAndGenerateGrpc(develop):
"""Develop mode that also generates gRPC stubs from proto files."""
def run(self):
compile_grpc_protos()
super().run()
class CMakeExtension(Extension):
def __init__(self, name: str, cmake_lists_dir: str = ".", **kwa) -> None:
super().__init__(name, sources=[], py_limited_api=not is_freethreaded(), **kwa)
@@ -1028,17 +951,12 @@ if _no_device():
ext_modules = []
if not ext_modules:
cmdclass = {
"build_py": BuildPyAndGenerateGrpc,
"develop": DevelopAndGenerateGrpc,
}
cmdclass = {}
else:
cmdclass = {
"build_ext": precompiled_build_ext
if envs.VLLM_USE_PRECOMPILED
else cmake_build_ext,
"build_py": BuildPyAndGenerateGrpc,
"develop": DevelopAndGenerateGrpc,
}
setup(
@@ -1064,6 +982,8 @@ setup(
"petit-kernel": ["petit-kernel"],
# Optional deps for Helion kernel development
"helion": ["helion"],
# Optional deps for gRPC server (vllm serve --grpc)
"grpc": ["smg-grpc-servicer >= 0.4.2"],
# Optional deps for OpenTelemetry tracing
"otel": [
"opentelemetry-sdk>=1.26.0",

View File

@@ -1,428 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
End-to-end tests for the vLLM gRPC server.
"""
import asyncio
import socket
import subprocess
import sys
import time
import grpc
import pytest
import pytest_asyncio
from vllm.grpc import vllm_engine_pb2, vllm_engine_pb2_grpc
# Use a small model for fast testing
MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
def find_free_port() -> int:
    """Ask the OS for an ephemeral TCP port that is currently unused.

    Binds a throwaway socket to port 0 (kernel picks a free port),
    reads the chosen port back, then releases the socket so the
    server under test can bind it.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        probe.bind(("", 0))
        probe.listen(1)
        return probe.getsockname()[1]
    finally:
        probe.close()
async def wait_for_server(port: int, timeout: float = 60.0) -> bool:
    """Wait for the gRPC server to be ready by trying health checks.

    Polls the HealthCheck RPC on ``localhost:<port>`` until the server
    reports healthy or *timeout* seconds have elapsed.

    Args:
        port: Port the gRPC server listens on (localhost).
        timeout: Maximum total seconds to keep polling.

    Returns:
        True if the server reported ``healthy`` within the timeout,
        False otherwise.
    """
    start_time = time.time()
    print("waiting for server to start...")
    while time.time() - start_time < timeout:
        try:
            # A fresh channel per attempt: early on, the server may not
            # even be accepting TCP connections yet.
            channel = grpc.aio.insecure_channel(f"localhost:{port}")
            stub = vllm_engine_pb2_grpc.VllmEngineStub(channel)
            request = vllm_engine_pb2.HealthCheckRequest()
            # Per-RPC deadline so a hung connection can't stall the poll.
            response = await stub.HealthCheck(request, timeout=5.0)
            await channel.close()
            if response.healthy:
                print("server returned healthy=True")
                return True
        except Exception:
            # Connection-refused / RPC errors are expected while the
            # engine is still loading; back off briefly and retry.
            await asyncio.sleep(0.5)
    return False
class GrpcServerProcess:
    """Manages a gRPC server running in a subprocess.

    Lifecycle: ``await start()`` launches ``python -m
    vllm.entrypoints.grpc_server`` on a free port and blocks until it
    answers health checks; ``stop()`` terminates it (gracefully first,
    then kill).
    """

    def __init__(self):
        # Both are populated by start(); None until then.
        self.process: subprocess.Popen | None = None
        self.port: int | None = None

    async def start(self):
        """Start the gRPC server process.

        Raises:
            RuntimeError: If the server does not become healthy within
                the default ``wait_for_server`` timeout; the subprocess
                is stopped before raising so it does not leak.
        """
        self.port = find_free_port()
        # Start the server as a subprocess
        self.process = subprocess.Popen(
            [
                sys.executable,
                "-m",
                "vllm.entrypoints.grpc_server",
                "--model",
                MODEL_NAME,
                "--host",
                "localhost",
                "--port",
                str(self.port),
                # Keep the engine small/quiet for fast test startup.
                "--max-num-batched-tokens",
                "512",
                "--disable-log-stats-server",
            ],
        )
        # Wait for server to be ready
        if not await wait_for_server(self.port):
            self.stop()
            raise RuntimeError("gRPC server failed to start within timeout")

    def stop(self):
        """Stop the gRPC server process.

        Tries SIGTERM first; if the process has not exited within 10
        seconds, escalates to SIGKILL and reaps it.
        """
        if self.process:
            self.process.terminate()
            try:
                self.process.wait(timeout=10)
            except subprocess.TimeoutExpired:
                self.process.kill()
                self.process.wait()
@pytest_asyncio.fixture(scope="module")
async def grpc_server():
    """Fixture providing a running gRPC server in a subprocess.

    Module-scoped so the (slow) engine startup happens once for every
    test in this file; the subprocess is stopped on teardown.
    """
    server = GrpcServerProcess()
    await server.start()
    yield server
    server.stop()
@pytest_asyncio.fixture
async def grpc_client(grpc_server):
    """Fixture providing a gRPC client connected to the server.

    Function-scoped: each test gets a fresh aio channel/stub, and the
    channel is closed on teardown.
    """
    channel = grpc.aio.insecure_channel(f"localhost:{grpc_server.port}")
    stub = vllm_engine_pb2_grpc.VllmEngineStub(channel)
    yield stub
    await channel.close()
@pytest.mark.asyncio
async def test_health_check(grpc_client):
    """Test the HealthCheck RPC returns healthy for a live server."""
    request = vllm_engine_pb2.HealthCheckRequest()
    response = await grpc_client.HealthCheck(request)
    assert response.healthy is True
    # "Health" is the literal message the servicer sends when healthy.
    assert response.message == "Health"
@pytest.mark.asyncio
async def test_get_model_info(grpc_client):
    """Test the GetModelInfo RPC.

    Checks the metadata reported for MODEL_NAME: a text-only generation
    model with a positive context length and vocab size.
    """
    request = vllm_engine_pb2.GetModelInfoRequest()
    response = await grpc_client.GetModelInfo(request)
    assert response.model_path == MODEL_NAME
    assert response.is_generation is True
    assert response.max_context_length > 0
    assert response.vocab_size > 0
    assert response.supports_vision is False
@pytest.mark.asyncio
async def test_get_server_info(grpc_client):
    """Test the GetServerInfo RPC.

    Only sanity-checks ranges (counters non-negative, timestamp set)
    since exact values depend on server timing.
    """
    request = vllm_engine_pb2.GetServerInfoRequest()
    response = await grpc_client.GetServerInfo(request)
    assert response.active_requests >= 0
    assert response.is_paused is False
    assert response.uptime_seconds >= 0
    assert response.server_type == "vllm-grpc"
    assert response.last_receive_timestamp > 0
@pytest.mark.asyncio
async def test_generate_non_streaming(grpc_client):
    """Test the Generate RPC in non-streaming mode.

    With stream=False the server should emit exactly one response
    carrying the final `complete` payload.
    """
    # Create a simple request
    request = vllm_engine_pb2.GenerateRequest(
        request_id="test-non-streaming-1",
        tokenized=vllm_engine_pb2.TokenizedInput(
            original_text="Hello, my name is",
            input_ids=[15496, 11, 616, 1438, 318],  # GPT-2 tokens for the prompt
        ),
        sampling_params=vllm_engine_pb2.SamplingParams(
            temperature=0.0,  # greedy decoding for determinism
            max_tokens=10,
            n=1,
        ),
        stream=False,
    )
    # Collect all responses
    responses = []
    async for response in grpc_client.Generate(request):
        responses.append(response)
    # Should have exactly one response (complete)
    assert len(responses) == 1
    # Check the response
    final_response = responses[0]
    assert final_response.HasField("complete")
    complete = final_response.complete
    assert len(complete.output_ids) > 0
    assert complete.finish_reason in ["stop", "length"]
    assert complete.prompt_tokens > 0
    assert complete.completion_tokens > 0
@pytest.mark.asyncio
async def test_generate_streaming(grpc_client):
    """Test the Generate RPC in streaming mode.

    With stream=True the server sends zero or more `chunk` responses
    followed by one final `complete` response.
    """
    request = vllm_engine_pb2.GenerateRequest(
        request_id="test-streaming-1",
        tokenized=vllm_engine_pb2.TokenizedInput(
            original_text="The capital of France is",
            input_ids=[464, 3139, 286, 4881, 318],  # GPT-2 tokens
        ),
        sampling_params=vllm_engine_pb2.SamplingParams(
            temperature=0.0, max_tokens=10, n=1
        ),
        stream=True,
    )
    # Collect all responses
    chunks = []
    complete_response = None
    async for response in grpc_client.Generate(request):
        if response.HasField("chunk"):
            chunks.append(response.chunk)
        elif response.HasField("complete"):
            complete_response = response.complete
    # Should have received some chunks
    # NOTE(review): `>= 0` is always true; it documents intent only.
    assert len(chunks) >= 0  # May have 0 chunks if generation is very fast
    # Should have a final complete response
    assert complete_response is not None
    assert complete_response.finish_reason in ["stop", "length"]
    assert complete_response.prompt_tokens > 0
    # Verify chunk structure
    for chunk in chunks:
        assert chunk.prompt_tokens > 0
        assert chunk.completion_tokens >= 0
@pytest.mark.asyncio
async def test_generate_with_different_sampling_params(grpc_client):
    """Test Generate with various sampling parameters.

    Only asserts that each request completes; output content is not
    checked since sampling is non-deterministic here.
    """
    # Test with temperature
    request = vllm_engine_pb2.GenerateRequest(
        request_id="test-sampling-temp",
        tokenized=vllm_engine_pb2.TokenizedInput(
            original_text="Hello",
            input_ids=[15496],
        ),
        sampling_params=vllm_engine_pb2.SamplingParams(
            temperature=0.8, top_p=0.95, max_tokens=5
        ),
        stream=False,
    )
    responses = [r async for r in grpc_client.Generate(request)]
    assert len(responses) == 1
    assert responses[0].HasField("complete")
    # Test with top_k
    request = vllm_engine_pb2.GenerateRequest(
        request_id="test-sampling-topk",
        tokenized=vllm_engine_pb2.TokenizedInput(
            original_text="Hello",
            input_ids=[15496],
        ),
        sampling_params=vllm_engine_pb2.SamplingParams(
            temperature=1.0, top_k=50, max_tokens=5
        ),
        stream=False,
    )
    responses = [r async for r in grpc_client.Generate(request)]
    assert len(responses) == 1
    assert responses[0].HasField("complete")
@pytest.mark.asyncio
async def test_generate_with_stop_strings(grpc_client):
    """Test Generate with stop strings.

    Passing `stop` sequences forces the server down the detokenizing
    code path; the request should still finish normally.
    """
    request = vllm_engine_pb2.GenerateRequest(
        request_id="test-stop-strings",
        tokenized=vllm_engine_pb2.TokenizedInput(
            original_text="Hello",
            input_ids=[15496],
        ),
        sampling_params=vllm_engine_pb2.SamplingParams(
            temperature=0.0,
            max_tokens=20,
            stop=["\n", "END"],
        ),
        stream=False,
    )
    responses = [r async for r in grpc_client.Generate(request)]
    assert len(responses) == 1
    assert responses[0].HasField("complete")
    complete = responses[0].complete
    assert complete.finish_reason in ["stop", "length"]
@pytest.mark.asyncio
async def test_generate_multiple_requests(grpc_client):
    """Test handling multiple concurrent Generate requests."""

    async def make_request(request_id: str):
        # One short, greedy, non-streaming request; returns the single
        # final response.
        request = vllm_engine_pb2.GenerateRequest(
            request_id=request_id,
            tokenized=vllm_engine_pb2.TokenizedInput(
                original_text="Hello",
                input_ids=[15496],
            ),
            sampling_params=vllm_engine_pb2.SamplingParams(
                temperature=0.0, max_tokens=5
            ),
            stream=False,
        )
        responses = [r async for r in grpc_client.Generate(request)]
        return responses[0]

    # Send multiple requests concurrently
    tasks = [make_request(f"test-concurrent-{i}") for i in range(3)]
    responses = await asyncio.gather(*tasks)
    # Verify all requests completed successfully
    assert len(responses) == 3
    for i, response in enumerate(responses):
        assert response.HasField("complete")
@pytest.mark.asyncio
async def test_generate_with_seed(grpc_client):
    """Test Generate with a fixed seed for reproducibility.

    Two requests with temperature=1.0 and the same seed must produce
    identical output token ids.
    """

    def make_request(request_id: str, seed: int):
        # Builds (does not send) a non-streaming request with the seed.
        return vllm_engine_pb2.GenerateRequest(
            request_id=request_id,
            tokenized=vllm_engine_pb2.TokenizedInput(
                original_text="The future of AI is",
                input_ids=[464, 2003, 286, 9552, 318],
            ),
            sampling_params=vllm_engine_pb2.SamplingParams(
                temperature=1.0, max_tokens=10, seed=seed
            ),
            stream=False,
        )

    # Make two requests with the same seed
    request1 = make_request("test-seed-1", 42)
    request2 = make_request("test-seed-2", 42)
    response_list1 = [r async for r in grpc_client.Generate(request1)]
    response_list2 = [r async for r in grpc_client.Generate(request2)]
    # Both should complete successfully
    assert len(response_list1) == 1
    assert len(response_list2) == 1
    assert response_list1[0].HasField("complete")
    assert response_list2[0].HasField("complete")
    # With the same seed, outputs should be identical
    output_ids1 = list(response_list1[0].complete.output_ids)
    output_ids2 = list(response_list2[0].complete.output_ids)
    assert output_ids1 == output_ids2
@pytest.mark.asyncio
async def test_generate_error_handling(grpc_client):
    """Test error handling in Generate RPC.

    An invalid sampling parameter should surface as an
    INVALID_ARGUMENT RpcError with the validation message in details.
    """
    # Request with invalid top_p value (-33)
    request = vllm_engine_pb2.GenerateRequest(
        request_id="test-error-invalid-topp",
        sampling_params=vllm_engine_pb2.SamplingParams(
            temperature=0.0, max_tokens=10, top_p=-33
        ),
        stream=False,
    )
    # Should raise an error response
    with pytest.raises(grpc.RpcError) as exc_info:
        _ = [r async for r in grpc_client.Generate(request)]
    assert exc_info.value.code() == grpc.StatusCode.INVALID_ARGUMENT
    assert "top_p must be in (0, 1], got -33.0" in exc_info.value.details()
@pytest.mark.asyncio
async def test_abort_request(grpc_client):
    """Test the out-of-band Abort RPC.

    Starts a long (500-token) streaming generation, then sends an
    Abort for its request id from a concurrent task; the stream must
    end early with finish_reason == "abort".
    """
    request_id = "test-abort-1"
    # Start a long-running streaming generate request
    generate_request = vllm_engine_pb2.GenerateRequest(
        request_id=request_id,
        tokenized=vllm_engine_pb2.TokenizedInput(
            original_text="Hello",
            input_ids=[15496],
        ),
        sampling_params=vllm_engine_pb2.SamplingParams(
            temperature=0.0,
            min_tokens=500,
            max_tokens=500,  # Request many tokens to ensure it runs long enough
        ),
        stream=True,
    )
    # Track whether we were aborted
    was_aborted = False
    received_chunks = 0

    async def run_generate():
        # Consumes the stream; the final `complete` message records
        # whether the server reported an abort.
        nonlocal was_aborted, received_chunks
        async for response in grpc_client.Generate(generate_request):
            if response.HasField("chunk"):
                received_chunks += 1
            if response.HasField("complete"):
                complete = response.complete
                was_aborted = complete.finish_reason == "abort"
            # NOTE(review): indentation reconstructed — this `else` is
            # read as pairing with the HasField("complete") check above
            # (a for-else would unconditionally reset the flag).
            else:
                was_aborted = False

    async def abort_after_delay():
        # Small delay to ensure generate has started
        await asyncio.sleep(0.1)
        abort_request = vllm_engine_pb2.AbortRequest(request_ids=[request_id])
        await grpc_client.Abort(abort_request)

    # Run generate and abort concurrently
    await asyncio.gather(run_generate(), abort_after_delay())
    # The request should have been aborted (received final chunk with
    # "abort" finish reason) and finished early due to the abort.
    assert was_aborted and received_chunks < 500, (
        "Request should have been aborted before generating all 500 tokens"
    )

View File

@@ -51,6 +51,12 @@ class ServeSubcommand(CLISubcommand):
if hasattr(args, "model_tag") and args.model_tag is not None:
args.model = args.model_tag
if getattr(args, "grpc", False):
from vllm.entrypoints.grpc_server import serve_grpc
uvloop.run(serve_grpc(args))
return
if args.headless:
if args.api_server_count is not None and args.api_server_count > 0:
raise ValueError(
@@ -127,6 +133,13 @@ class ServeSubcommand(CLISubcommand):
)
serve_parser = make_arg_parser(serve_parser)
serve_parser.add_argument(
"--grpc",
action="store_true",
default=False,
help="Launch a gRPC server instead of the HTTP OpenAI-compatible "
"server. Requires: pip install vllm[grpc].",
)
serve_parser.epilog = VLLM_SUBCMD_PARSER_EPILOG.format(subcmd=self.name)
return serve_parser

457
vllm/entrypoints/grpc_server.py Executable file → Normal file
View File

@@ -5,7 +5,8 @@
"""
vLLM gRPC Server
Starts a gRPC server for vLLM using the VllmEngine protocol.
Starts a gRPC server backed by AsyncLLM, using the VllmEngineServicer
from the smg-grpc-servicer package.
Usage:
python -m vllm.entrypoints.grpc_server --model <model_path>
@@ -22,19 +23,23 @@ import asyncio
import signal
import sys
import time
from collections.abc import AsyncGenerator
import grpc
try:
import grpc
from grpc_reflection.v1alpha import reflection
from smg_grpc_proto import vllm_engine_pb2, vllm_engine_pb2_grpc
from smg_grpc_servicer.vllm.servicer import VllmEngineServicer
except ImportError:
raise ImportError(
"smg-grpc-servicer is required for gRPC mode. "
"Install it with: pip install vllm[grpc]"
) from None
import uvloop
from grpc_reflection.v1alpha import reflection
from vllm import SamplingParams, TextPrompt, TokensPrompt
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.entrypoints.utils import log_version_and_model
from vllm.grpc import vllm_engine_pb2, vllm_engine_pb2_grpc
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.sampling_params import RequestOutputKind, StructuredOutputsParams
from vllm.usage.usage_lib import UsageContext
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.v1.engine.async_llm import AsyncLLM
@@ -43,377 +48,9 @@ from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
class VllmEngineServicer(vllm_engine_pb2_grpc.VllmEngineServicer):
"""
gRPC servicer implementing the VllmEngine service.
Handles 6 RPCs:
- Generate: Streaming text generation
- Embed: Embeddings (TODO)
- HealthCheck: Health probe
- Abort: Cancel requests out-of-band
- GetModelInfo: Model metadata
- GetServerInfo: Server state
"""
def __init__(self, async_llm: AsyncLLM, start_time: float):
"""
Initialize the servicer.
Args:
async_llm: The AsyncLLM instance
start_time: The server start time, in seconds since epoch
"""
self.async_llm = async_llm
self.start_time = start_time
logger.info("VllmEngineServicer initialized")
async def Generate(
self,
request: vllm_engine_pb2.GenerateRequest,
context: grpc.aio.ServicerContext,
) -> AsyncGenerator[vllm_engine_pb2.GenerateResponse, None]:
"""
Handle streaming generation requests.
Args:
request: The GenerateRequest protobuf
context: gRPC context
Yields:
GenerateResponse protobuf messages (streaming)
"""
request_id = request.request_id
logger.debug("Generate request %s received.", request_id)
try:
# Extract tokenized input
if request.WhichOneof("input") == "tokenized":
prompt: TokensPrompt = {
"prompt_token_ids": list(request.tokenized.input_ids)
}
if request.tokenized.original_text:
prompt["prompt"] = request.tokenized.original_text
else:
prompt: TextPrompt = {"prompt": request.text}
# Build sampling params with detokenize=False
sampling_params = self._sampling_params_from_proto(
request.sampling_params, stream=request.stream
)
tokenization_kwargs = self._tokenization_kwargs_from_proto(
request.sampling_params
)
async for output in self.async_llm.generate(
prompt=prompt,
sampling_params=sampling_params,
request_id=request_id,
tokenization_kwargs=tokenization_kwargs,
):
# Convert vLLM output to protobuf
# For streaming, always send chunks
if request.stream:
yield self._chunk_response(output)
# Send complete response when finished
if output.finished:
yield self._complete_response(output)
except ValueError as e:
# Invalid request error (equiv to 400).
await context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
except Exception as e:
logger.exception("Error in Generate for request %s", request_id)
await context.abort(grpc.StatusCode.INTERNAL, str(e))
async def Embed(
self,
request: vllm_engine_pb2.EmbedRequest,
context: grpc.aio.ServicerContext,
) -> vllm_engine_pb2.EmbedResponse:
"""
Handle embedding requests.
TODO: Implement in Phase 4
Args:
request: The EmbedRequest protobuf
context: gRPC context
Returns:
EmbedResponse protobuf
"""
logger.warning("Embed RPC not yet implemented")
await context.abort(
grpc.StatusCode.UNIMPLEMENTED, "Embed RPC not yet implemented"
)
async def HealthCheck(
self,
request: vllm_engine_pb2.HealthCheckRequest,
context: grpc.aio.ServicerContext,
) -> vllm_engine_pb2.HealthCheckResponse:
"""
Handle health check requests.
Args:
request: The HealthCheckRequest protobuf
context: gRPC context
Returns:
HealthCheckResponse protobuf
"""
is_healthy = not self.async_llm.errored
message = "Health" if is_healthy else "Engine is not alive"
logger.debug("HealthCheck request: healthy=%s, message=%s", is_healthy, message)
return vllm_engine_pb2.HealthCheckResponse(healthy=is_healthy, message=message)
async def Abort(
self,
request: vllm_engine_pb2.AbortRequest,
context: grpc.aio.ServicerContext,
) -> vllm_engine_pb2.AbortResponse:
"""
Out-of-band abort requests.
Args:
request: The AbortRequest protobuf
context: gRPC context
Returns:
AbortResponse protobuf
"""
request_ids = request.request_ids
logger.debug("Abort requests: %s", request_ids)
await self.async_llm.abort(request_ids)
return vllm_engine_pb2.AbortResponse()
async def GetModelInfo(
self,
request: vllm_engine_pb2.GetModelInfoRequest,
context: grpc.aio.ServicerContext,
) -> vllm_engine_pb2.GetModelInfoResponse:
"""
Handle model info requests.
Args:
request: The GetModelInfoRequest protobuf
context: gRPC context
Returns:
GetModelInfoResponse protobuf
"""
model_config = self.async_llm.model_config
return vllm_engine_pb2.GetModelInfoResponse(
model_path=model_config.model,
is_generation=model_config.runner_type == "generate",
max_context_length=model_config.max_model_len,
vocab_size=model_config.get_vocab_size(),
supports_vision=model_config.is_multimodal_model,
)
async def GetServerInfo(
self,
request: vllm_engine_pb2.GetServerInfoRequest,
context: grpc.aio.ServicerContext,
) -> vllm_engine_pb2.GetServerInfoResponse:
"""
Handle server info requests.
Args:
request: The GetServerInfoRequest protobuf
context: gRPC context
Returns:
GetServerInfoResponse protobuf
"""
num_requests = self.async_llm.output_processor.get_num_unfinished_requests()
return vllm_engine_pb2.GetServerInfoResponse(
active_requests=num_requests,
is_paused=False, # TODO
last_receive_timestamp=time.time(), # TODO looks wrong?
uptime_seconds=time.time() - self.start_time,
server_type="vllm-grpc",
)
# ========== Helper methods ==========
@staticmethod
def _sampling_params_from_proto(
params: vllm_engine_pb2.SamplingParams, stream: bool = True
) -> SamplingParams:
"""
Convert protobuf SamplingParams to vLLM SamplingParams.
Args:
params: Protobuf SamplingParams message
stream: Whether streaming is enabled
Returns:
vLLM SamplingParams with detokenize=False and structured_outputs
"""
# Build stop sequences
stop = list(params.stop) if params.stop else None
stop_token_ids = list(params.stop_token_ids) if params.stop_token_ids else None
# Handle structured outputs constraints
structured_outputs = None
constraint_field = params.WhichOneof("constraint")
if constraint_field:
if constraint_field == "json_schema":
structured_outputs = StructuredOutputsParams(json=params.json_schema)
elif constraint_field == "regex":
structured_outputs = StructuredOutputsParams(regex=params.regex)
elif constraint_field == "grammar":
structured_outputs = StructuredOutputsParams(grammar=params.grammar)
elif constraint_field == "structural_tag":
structured_outputs = StructuredOutputsParams(
structural_tag=params.structural_tag
)
elif constraint_field == "json_object":
structured_outputs = StructuredOutputsParams(
json_object=params.json_object
)
elif constraint_field == "choice":
structured_outputs = StructuredOutputsParams(
choice=list(params.choice.choices)
)
# Create SamplingParams
# output_kind=DELTA: Return only new tokens in each chunk (for streaming)
return SamplingParams(
temperature=params.temperature if params.HasField("temperature") else 1.0,
top_p=params.top_p if params.top_p != 0.0 else 1.0,
top_k=params.top_k,
min_p=params.min_p,
frequency_penalty=params.frequency_penalty,
presence_penalty=params.presence_penalty,
repetition_penalty=params.repetition_penalty
if params.repetition_penalty != 0.0
else 1.0,
max_tokens=params.max_tokens if params.HasField("max_tokens") else None,
min_tokens=params.min_tokens,
stop=stop,
stop_token_ids=stop_token_ids,
skip_special_tokens=params.skip_special_tokens,
spaces_between_special_tokens=params.spaces_between_special_tokens,
ignore_eos=params.ignore_eos,
n=params.n if params.n > 0 else 1,
logprobs=params.logprobs if params.HasField("logprobs") else None,
prompt_logprobs=params.prompt_logprobs
if params.HasField("prompt_logprobs")
else None,
seed=params.seed if params.HasField("seed") else None,
include_stop_str_in_output=params.include_stop_str_in_output,
logit_bias=dict(params.logit_bias) if params.logit_bias else None,
structured_outputs=structured_outputs,
# detokenize must be True if stop strings are used
detokenize=bool(stop),
output_kind=RequestOutputKind.DELTA
if stream
else RequestOutputKind.FINAL_ONLY,
)
@staticmethod
def _tokenization_kwargs_from_proto(
params: vllm_engine_pb2.SamplingParams,
) -> dict[str, int] | None:
if params.HasField("truncate_prompt_tokens"):
return {"truncate_prompt_tokens": params.truncate_prompt_tokens}
return None
@staticmethod
def _chunk_response(output: RequestOutput) -> vllm_engine_pb2.GenerateResponse:
"""
Build a streaming chunk response from vLLM output.
When output_kind=DELTA, vLLM returns only new tokens automatically.
Args:
output: vLLM RequestOutput (with delta tokens when output_kind=DELTA)
Returns:
GenerateResponse with chunk field set
"""
# Get the completion output (first one if n > 1)
completion = output.outputs[0] if output.outputs else None
if completion is None:
# Empty chunk
return vllm_engine_pb2.GenerateResponse(
chunk=vllm_engine_pb2.GenerateStreamChunk(
token_ids=[],
prompt_tokens=0,
completion_tokens=0,
cached_tokens=0,
),
)
# When output_kind=DELTA, completion.token_ids contains only new tokens
# vLLM handles the delta logic internally
# completion_tokens = delta count (client will accumulate)
return vllm_engine_pb2.GenerateResponse(
chunk=vllm_engine_pb2.GenerateStreamChunk(
token_ids=completion.token_ids,
prompt_tokens=len(output.prompt_token_ids)
if output.prompt_token_ids
else 0,
completion_tokens=len(completion.token_ids), # Delta count
cached_tokens=output.num_cached_tokens,
),
)
@staticmethod
def _complete_response(output: RequestOutput) -> vllm_engine_pb2.GenerateResponse:
"""
Build a final completion response from vLLM output.
Args:
output: vLLM RequestOutput (finished=True)
Returns:
GenerateResponse with complete field set
"""
# Get the completion output (first one if n > 1)
completion = output.outputs[0] if output.outputs else None
if completion is None:
# Empty completion
return vllm_engine_pb2.GenerateResponse(
complete=vllm_engine_pb2.GenerateComplete(
output_ids=[],
finish_reason="error",
prompt_tokens=0,
completion_tokens=0,
cached_tokens=0,
),
)
# Build complete response
# When streaming (DELTA mode): completion.token_ids will be empty/last delta
# When non-streaming (FINAL_ONLY mode): completion.token_ids has all tokens
# Client will accumulate token counts for streaming
return vllm_engine_pb2.GenerateResponse(
complete=vllm_engine_pb2.GenerateComplete(
output_ids=completion.token_ids,
finish_reason=completion.finish_reason or "stop",
prompt_tokens=len(output.prompt_token_ids)
if output.prompt_token_ids
else 0,
completion_tokens=len(completion.token_ids),
cached_tokens=output.num_cached_tokens,
),
)
async def serve_grpc(args: argparse.Namespace):
"""
Main serving function.
Main gRPC serving function.
Args:
args: Parsed command line arguments
@@ -428,7 +65,7 @@ async def serve_grpc(args: argparse.Namespace):
# Build vLLM config
vllm_config = engine_args.create_engine_config(
usage_context=UsageContext.OPENAI_API_SERVER
usage_context=UsageContext.OPENAI_API_SERVER,
)
# Create AsyncLLM
@@ -436,7 +73,7 @@ async def serve_grpc(args: argparse.Namespace):
vllm_config=vllm_config,
usage_context=UsageContext.OPENAI_API_SERVER,
enable_log_requests=args.enable_log_requests,
disable_log_stats=args.disable_log_stats_server,
disable_log_stats=args.disable_log_stats,
)
# Create servicer
@@ -447,6 +84,11 @@ async def serve_grpc(args: argparse.Namespace):
options=[
("grpc.max_send_message_length", -1),
("grpc.max_receive_message_length", -1),
# Tolerate client keepalive pings every 10s (default 300s is too
# strict for non-streaming requests where no DATA frames flow
# during generation)
("grpc.http2.min_recv_ping_interval_without_data_ms", 10000),
("grpc.keepalive_permit_without_calls", True),
],
)
@@ -461,46 +103,42 @@ async def serve_grpc(args: argparse.Namespace):
reflection.enable_server_reflection(service_names, server)
# Bind to address
address = f"{args.host}:{args.port}"
host = args.host or "0.0.0.0"
address = f"{host}:{args.port}"
server.add_insecure_port(address)
# Start server
await server.start()
logger.info("vLLM gRPC server started on %s", address)
logger.info("Server is ready to accept requests")
# Handle shutdown signals
loop = asyncio.get_running_loop()
stop_event = asyncio.Event()
def signal_handler():
logger.info("Received shutdown signal")
stop_event.set()
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, signal_handler)
# Serve until shutdown signal
try:
await stop_event.wait()
except KeyboardInterrupt:
logger.info("Interrupted by user")
# Start server
await server.start()
logger.info("vLLM gRPC server started on %s", address)
logger.info("Server is ready to accept requests")
# Handle shutdown signals
loop = asyncio.get_running_loop()
stop_event = asyncio.Event()
def signal_handler():
logger.info("Received shutdown signal")
stop_event.set()
for sig in (signal.SIGTERM, signal.SIGINT):
loop.add_signal_handler(sig, signal_handler)
try:
await stop_event.wait()
except KeyboardInterrupt:
logger.info("Interrupted by user")
finally:
logger.info("Shutting down vLLM gRPC server...")
# Stop gRPC server
await server.stop(grace=5.0)
logger.info("gRPC server stopped")
# Shutdown AsyncLLM
async_llm.shutdown()
logger.info("AsyncLLM engine stopped")
logger.info("Shutdown complete")
def main():
"""Main entry point."""
"""Main entry point for python -m vllm.entrypoints.grpc_server."""
parser = FlexibleArgumentParser(
description="vLLM gRPC Server",
)
@@ -518,13 +156,6 @@ def main():
default=50051,
help="Port to bind gRPC server to",
)
parser.add_argument(
"--disable-log-stats-server",
action="store_true",
help="Disable stats logging on server side",
)
# Add vLLM engine args
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()

View File

@@ -1,17 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
vLLM gRPC protocol definitions.

This module contains the protocol buffer definitions for vLLM's gRPC API.
The protobuf files are compiled into Python code using grpcio-tools.
"""
# These imports will be available after protobuf compilation; they are kept
# commented out so importing this package does not fail before the generated
# *_pb2 modules exist.
# from vllm.grpc import vllm_engine_pb2
# from vllm.grpc import vllm_engine_pb2_grpc

# NOTE(review): the names in __all__ are not bound in this module until the
# generated modules are actually imported, so `from vllm.grpc import *` will
# raise AttributeError before proto compilation — confirm this is intended.
__all__ = [
    "vllm_engine_pb2",
    "vllm_engine_pb2_grpc",
]

View File

@@ -1,94 +0,0 @@
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Compile vLLM protobuf definitions into Python code.
This script uses grpcio-tools to generate *_pb2.py, *_pb2_grpc.py, and
*_pb2.pyi (type stubs) files from the vllm_engine.proto definition.
NOTE: Proto compilation happens automatically during package build (via setup.py).
This script is provided for developers who want to regenerate protos manually,
e.g., after modifying vllm_engine.proto.
Usage:
python vllm/grpc/compile_protos.py
Requirements:
pip install grpcio-tools
"""
import sys
from pathlib import Path
def compile_protos() -> int:
    """Compile vllm_engine.proto into Python modules and type stubs.

    Generates ``vllm_engine_pb2.py``, ``vllm_engine_pb2_grpc.py`` and
    ``vllm_engine_pb2.pyi`` next to the proto file, then prepends SPDX
    license headers to each generated file.

    Returns:
        0 on success; a non-zero exit code on failure (suitable for
        ``sys.exit``). All error diagnostics go to stderr so the exit
        status and normal progress output stay cleanly separated.
    """
    # Layout: vllm/vllm/grpc -> vllm/. Compiling with the package root as
    # proto_path makes the generated modules importable as vllm.grpc.*.
    script_dir = Path(__file__).parent
    vllm_package_root = script_dir.parent.parent
    proto_file = script_dir / "vllm_engine.proto"

    if not proto_file.exists():
        print(f"Error: Proto file not found at {proto_file}", file=sys.stderr)
        return 1

    print(f"Compiling protobuf: {proto_file}")
    print(f"Output directory: {script_dir}")

    # Import separately so a missing grpcio-tools yields a targeted hint
    # instead of being swallowed by the generic compilation handler.
    try:
        from grpc_tools import protoc
    except ImportError:
        print("Error: grpcio-tools not installed", file=sys.stderr)
        print("Install with: pip install grpcio-tools", file=sys.stderr)
        return 1

    try:
        result = protoc.main(
            [
                "grpc_tools.protoc",
                f"--proto_path={vllm_package_root}",
                f"--python_out={vllm_package_root}",
                f"--grpc_python_out={vllm_package_root}",
                f"--pyi_out={vllm_package_root}",  # Generate type stubs
                str(proto_file),  # Reuse the path validated above
            ]
        )
    except Exception as e:
        print(f"Error during compilation: {e}", file=sys.stderr)
        return 1

    if result != 0:
        print(f"Error: protoc returned {result}", file=sys.stderr)
        return result

    _stamp_spdx_headers(script_dir)
    print("✓ Protobuf compilation successful!")
    print(f"  Generated: {script_dir / 'vllm_engine_pb2.py'}")
    print(f"  Generated: {script_dir / 'vllm_engine_pb2_grpc.py'}")
    print(f"  Generated: {script_dir / 'vllm_engine_pb2.pyi'} (type stubs)")
    return 0


def _stamp_spdx_headers(script_dir: Path) -> None:
    """Prepend SPDX headers (plus a mypy suppression) to generated files.

    Idempotent: files that already start with an SPDX header are left alone.
    """
    spdx_header = (
        "# SPDX-License-Identifier: Apache-2.0\n"
        "# SPDX-FileCopyrightText: Copyright contributors to the vLLM project\n"
    )
    for generated_file in (
        script_dir / "vllm_engine_pb2.py",
        script_dir / "vllm_engine_pb2_grpc.py",
        script_dir / "vllm_engine_pb2.pyi",
    ):
        if not generated_file.exists():
            continue
        content = generated_file.read_text()
        if not content.startswith("# SPDX-License-Identifier"):
            # Generated code is excluded from type checking via ignore-errors.
            header = spdx_header + "# mypy: ignore-errors\n"
            generated_file.write_text(header + content)


if __name__ == "__main__":
    sys.exit(compile_protos())

View File

@@ -1,195 +0,0 @@
syntax = "proto3";
package vllm.grpc.engine;
// Service definition for vLLM engine communication.
// This protocol is designed for efficient binary communication between
// the Rust router and vLLM Python engine (AsyncLLM).
service VllmEngine {
  // Submit a generation request. The response stream carries
  // GenerateStreamChunk messages while streaming and ends with (or, for
  // non-streaming requests, consists solely of) a GenerateComplete message.
  rpc Generate(GenerateRequest) returns (stream GenerateResponse);
  // Submit an embedding request (unary; input must be pre-tokenized).
  rpc Embed(EmbedRequest) returns (EmbedResponse);
  // Liveness/health probe.
  rpc HealthCheck(HealthCheckRequest) returns (HealthCheckResponse);
  // Abort one or more in-flight requests by request ID.
  rpc Abort(AbortRequest) returns (AbortResponse);
  // Get static model information (context length, vocab size, ...).
  rpc GetModelInfo(GetModelInfoRequest) returns (GetModelInfoResponse);
  // Get dynamic server state (active requests, uptime, ...).
  rpc GetServerInfo(GetServerInfoRequest) returns (GetServerInfoResponse);
}
// =====================
// Common Types
// =====================

// Sampling parameters for text generation.
// Field numbers 1-15 are core sampling knobs, 16-21 belong to the
// structured-output oneof, and 22-27 carry logprobs and misc options.
message SamplingParams {
  optional float temperature = 1;
  float top_p = 2;
  uint32 top_k = 3;
  float min_p = 4;
  float frequency_penalty = 5;
  float presence_penalty = 6;
  float repetition_penalty = 7;
  optional uint32 max_tokens = 8;
  uint32 min_tokens = 9;
  repeated string stop = 10;           // Stop strings
  repeated uint32 stop_token_ids = 11; // Stop token IDs
  bool skip_special_tokens = 12;
  bool spaces_between_special_tokens = 13;
  bool ignore_eos = 14;
  uint32 n = 15; // Number of parallel samples

  // Logprobs configuration
  optional int32 logprobs = 22;        // Number of log probabilities per output token (-1 for all)
  optional int32 prompt_logprobs = 23; // Number of log probabilities per prompt token (-1 for all)

  // Additional vLLM fields
  optional int32 seed = 24;                   // Random seed for reproducibility
  bool include_stop_str_in_output = 25;       // Whether to include stop strings in output
  map<int32, float> logit_bias = 26;          // Token ID to bias mapping (-100 to 100)
  optional int32 truncate_prompt_tokens = 27; // Prompt truncation (-1 for model max)

  // Structured outputs (one of) - matches vLLM's StructuredOutputsParams
  oneof constraint {
    string json_schema = 16;      // JSON schema for structured output
    string regex = 17;            // Regex pattern
    string grammar = 18;          // Grammar/EBNF for structured output
    string structural_tag = 19;   // Structural tag (e.g., Harmony models)
    bool json_object = 20;        // Force JSON object output
    ChoiceConstraint choice = 21; // List of allowed choices
  }
}

// Choice constraint for structured outputs (wraps the repeated field, since
// a repeated field cannot appear directly inside a oneof).
message ChoiceConstraint {
  repeated string choices = 1;
}

// Pre-tokenized input from the Rust router.
message TokenizedInput {
  string original_text = 1;      // For reference/debugging
  repeated uint32 input_ids = 2; // Actual token IDs to process
}
// =====================
// Generate Request
// =====================
message GenerateRequest {
  string request_id = 1; // Caller-assigned ID; also used by Abort
  // Prompt input: either pre-tokenized IDs or raw text
  oneof input {
    TokenizedInput tokenized = 2;
    string text = 3;
  }
  // Generation parameters (includes logprobs config)
  SamplingParams sampling_params = 4;
  // Streaming: when true, responses arrive as incremental chunks
  bool stream = 5;
}

// =====================
// Generate Response
// =====================
// Each stream element is either an incremental chunk or the final summary.
message GenerateResponse {
  oneof response {
    GenerateStreamChunk chunk = 1;  // For streaming
    GenerateComplete complete = 2;  // For final/non-streaming
  }
}

message GenerateStreamChunk {
  repeated uint32 token_ids = 1; // Incremental tokens
  uint32 prompt_tokens = 2;      // Prompt length in tokens
  uint32 completion_tokens = 3;  // Tokens generated so far
  uint32 cached_tokens = 4;      // presumably prefix-cache hits — verify against servicer
  // Logprobs support (TODO: implement in Phase 4)
  // OutputLogProbs output_logprobs = 5;
  // InputLogProbs input_logprobs = 6; // Only in first chunk
}

message GenerateComplete {
  repeated uint32 output_ids = 1; // All output tokens
  string finish_reason = 2;       // "stop", "length", "abort"
  uint32 prompt_tokens = 3;
  uint32 completion_tokens = 4;
  uint32 cached_tokens = 5;
  // Logprobs support (TODO: implement in Phase 4)
  // OutputLogProbs output_logprobs = 6;
  // InputLogProbs input_logprobs = 7;
}
// =====================
// Embedding Request
// =====================
message EmbedRequest {
  string request_id = 1;
  TokenizedInput tokenized = 2; // Embedding input is always pre-tokenized
}

message EmbedResponse {
  repeated float embedding = 1; // Embedding vector; presumably length == embedding_dim — verify
  uint32 prompt_tokens = 2;     // Number of input tokens processed
  uint32 embedding_dim = 3;     // Dimensionality of the embedding
}
// =====================
// Management Operations
// =====================
message HealthCheckRequest {}

message HealthCheckResponse {
  bool healthy = 1;
  string message = 2; // Human-readable status detail
}

message AbortRequest {
  repeated string request_ids = 1; // IDs of in-flight requests to cancel
}

// Intentionally empty: completion of the RPC itself signals the abort
// was accepted.
message AbortResponse {
}

// =====================
// Model and Server Info
// =====================
message GetModelInfoRequest {}

message GetModelInfoResponse {
  string model_path = 1;
  bool is_generation = 2;        // True for generative models (vs. embedding-only)
  uint32 max_context_length = 3;
  uint32 vocab_size = 4;
  bool supports_vision = 5;
}

message GetServerInfoRequest {}

message GetServerInfoResponse {
  uint32 active_requests = 1;
  bool is_paused = 2;
  double last_receive_timestamp = 3; // presumably Unix epoch seconds — confirm with servicer
  double uptime_seconds = 4;
  string server_type = 5; // "vllm-grpc"
}