feat(grpc): extract gRPC servicer into smg-grpc-servicer package, add --grpc flag to vllm serve (#36169)
Signed-off-by: Chang Su <chang.s.su@oracle.com> Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -9,7 +9,6 @@ requires = [
|
||||
"torch == 2.10.0",
|
||||
"wheel",
|
||||
"jinja2",
|
||||
"grpcio-tools==1.78.0",
|
||||
]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
@@ -57,10 +56,6 @@ include = ["vllm*"]
|
||||
"vllm/third_party/**" = ["ALL"]
|
||||
"vllm/version.py" = ["F401"]
|
||||
"vllm/_version.py" = ["ALL"]
|
||||
# Exclude generated protobuf files
|
||||
"vllm/grpc/*_pb2.py" = ["ALL"]
|
||||
"vllm/grpc/*_pb2_grpc.py" = ["ALL"]
|
||||
"vllm/grpc/*_pb2.pyi" = ["ALL"]
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [
|
||||
|
||||
@@ -10,4 +10,3 @@ jinja2>=3.1.6
|
||||
regex
|
||||
build
|
||||
protobuf >= 5.29.6, !=6.30.*, !=6.31.*, !=6.32.*, !=6.33.0.*, !=6.33.1.*, !=6.33.2.*, !=6.33.3.*, !=6.33.4.*
|
||||
grpcio-tools==1.78.0 # Required for grpc entrypoints
|
||||
|
||||
@@ -51,8 +51,6 @@ openai-harmony >= 0.0.3 # Required for gpt-oss
|
||||
anthropic >= 0.71.0
|
||||
model-hosting-container-standards >= 0.1.13, < 1.0.0
|
||||
mcp
|
||||
grpcio
|
||||
grpcio-reflection
|
||||
opentelemetry-sdk >= 1.27.0
|
||||
opentelemetry-api >= 1.27.0
|
||||
opentelemetry-exporter-otlp >= 1.27.0
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
# The version of gRPC libraries should be consistent with each other
|
||||
grpcio==1.78.0
|
||||
grpcio-reflection==1.78.0
|
||||
grpcio-tools==1.78.0
|
||||
|
||||
numba == 0.61.2 # Required for N-gram speculative decoding
|
||||
|
||||
|
||||
@@ -51,7 +51,6 @@ tritonclient>=2.51.0
|
||||
# The version of gRPC libraries should be consistent with each other
|
||||
grpcio==1.78.0
|
||||
grpcio-reflection==1.78.0
|
||||
grpcio-tools==1.78.0
|
||||
|
||||
arctic-inference == 0.1.1 # Required for suffix decoding test
|
||||
numba == 0.61.2 # Required for N-gram speculative decoding
|
||||
|
||||
@@ -289,13 +289,10 @@ grpcio==1.78.0
|
||||
# via
|
||||
# -r requirements/test.in
|
||||
# grpcio-reflection
|
||||
# grpcio-tools
|
||||
# ray
|
||||
# tensorboard
|
||||
grpcio-reflection==1.78.0
|
||||
# via -r requirements/test.in
|
||||
grpcio-tools==1.78.0
|
||||
# via -r requirements/test.in
|
||||
h11==0.14.0
|
||||
# via
|
||||
# httpcore
|
||||
@@ -765,7 +762,6 @@ protobuf==6.33.2
|
||||
# google-api-core
|
||||
# googleapis-common-protos
|
||||
# grpcio-reflection
|
||||
# grpcio-tools
|
||||
# opentelemetry-proto
|
||||
# proto-plus
|
||||
# ray
|
||||
@@ -1045,7 +1041,6 @@ sentry-sdk==2.52.0
|
||||
# via wandb
|
||||
setuptools==77.0.3
|
||||
# via
|
||||
# grpcio-tools
|
||||
# lightning-utilities
|
||||
# pytablewriter
|
||||
# tensorboard
|
||||
|
||||
86
setup.py
@@ -18,8 +18,6 @@ import torch
|
||||
from packaging.version import Version, parse
|
||||
from setuptools import Extension, setup
|
||||
from setuptools.command.build_ext import build_ext
|
||||
from setuptools.command.build_py import build_py
|
||||
from setuptools.command.develop import develop
|
||||
from setuptools_scm import get_version
|
||||
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME
|
||||
|
||||
@@ -81,81 +79,6 @@ def is_freethreaded():
|
||||
return bool(sysconfig.get_config_var("Py_GIL_DISABLED"))
|
||||
|
||||
|
||||
def compile_grpc_protos():
    """Generate the Python gRPC stubs for ``vllm/grpc/vllm_engine.proto``.

    Runs ``grpc_tools.protoc`` with ``ROOT_DIR`` as both the proto search
    path and the output root, which produces ``vllm_engine_pb2.py``,
    ``vllm_engine_pb2_grpc.py`` and ``vllm_engine_pb2.pyi`` under
    ``vllm/grpc``. Every generated file that does not already start with an
    SPDX identifier gets the licence header (plus a ``mypy: ignore-errors``
    directive) prepended.

    Returns:
        ``True`` when the stubs were generated. ``False`` when
        ``grpcio-tools`` cannot be imported, the proto file is missing, or
        ``protoc`` exits with a non-zero status; each of these is logged
        instead of raised.
    """
    try:
        from grpc_tools import protoc
    except ImportError:
        logger.warning(
            "grpcio-tools not installed, skipping gRPC proto compilation. "
            "gRPC server functionality will not be available."
        )
        return False

    grpc_dir = ROOT_DIR / "vllm" / "grpc"
    proto_file = grpc_dir / "vllm_engine.proto"
    if not proto_file.exists():
        logger.warning("Proto file not found at %s, skipping compilation", proto_file)
        return False

    logger.info("Compiling gRPC protobuf: %s", proto_file)

    protoc_args = [
        "grpc_tools.protoc",
        f"--proto_path={ROOT_DIR}",
        f"--python_out={ROOT_DIR}",
        f"--grpc_python_out={ROOT_DIR}",
        f"--pyi_out={ROOT_DIR}",
        str(proto_file),
    ]
    exit_code = protoc.main(protoc_args)
    if exit_code != 0:
        logger.error("protoc failed with exit code %s", exit_code)
        return False

    # Licence header and mypy opt-out for the generated sources.
    header = (
        "# SPDX-License-Identifier: Apache-2.0\n"
        "# SPDX-FileCopyrightText: Copyright contributors to the vLLM project\n"
        "# mypy: ignore-errors\n"
    )
    for suffix in ("_pb2.py", "_pb2_grpc.py", "_pb2.pyi"):
        stub = grpc_dir / f"vllm_engine{suffix}"
        if not stub.exists():
            continue
        text = stub.read_text()
        if text.startswith("# SPDX-License-Identifier"):
            # Already stamped by an earlier build; do not stack headers.
            continue
        stub.write_text(header + text)

    logger.info("gRPC protobuf compilation successful")
    return True
|
||||
|
||||
|
||||
class BuildPyAndGenerateGrpc(build_py):
    """``build_py`` variant that regenerates the gRPC stubs first."""

    def run(self):
        # Generate the *_pb2 modules before build_py runs. The boolean
        # result is ignored on purpose: a failed generation is only logged.
        compile_grpc_protos()
        super().run()
|
||||
|
||||
|
||||
class DevelopAndGenerateGrpc(develop):
    """``develop`` variant that regenerates the gRPC stubs first."""

    def run(self):
        # Generate the *_pb2 modules before the develop command runs. The
        # boolean result is ignored on purpose: failures are only logged.
        compile_grpc_protos()
        super().run()
|
||||
|
||||
|
||||
class CMakeExtension(Extension):
|
||||
def __init__(self, name: str, cmake_lists_dir: str = ".", **kwa) -> None:
|
||||
super().__init__(name, sources=[], py_limited_api=not is_freethreaded(), **kwa)
|
||||
@@ -1028,17 +951,12 @@ if _no_device():
|
||||
ext_modules = []
|
||||
|
||||
if not ext_modules:
|
||||
cmdclass = {
|
||||
"build_py": BuildPyAndGenerateGrpc,
|
||||
"develop": DevelopAndGenerateGrpc,
|
||||
}
|
||||
cmdclass = {}
|
||||
else:
|
||||
cmdclass = {
|
||||
"build_ext": precompiled_build_ext
|
||||
if envs.VLLM_USE_PRECOMPILED
|
||||
else cmake_build_ext,
|
||||
"build_py": BuildPyAndGenerateGrpc,
|
||||
"develop": DevelopAndGenerateGrpc,
|
||||
}
|
||||
|
||||
setup(
|
||||
@@ -1064,6 +982,8 @@ setup(
|
||||
"petit-kernel": ["petit-kernel"],
|
||||
# Optional deps for Helion kernel development
|
||||
"helion": ["helion"],
|
||||
# Optional deps for gRPC server (vllm serve --grpc)
|
||||
"grpc": ["smg-grpc-servicer >= 0.4.2"],
|
||||
# Optional deps for OpenTelemetry tracing
|
||||
"otel": [
|
||||
"opentelemetry-sdk>=1.26.0",
|
||||
|
||||
@@ -1,428 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
End-to-end tests for the vLLM gRPC server.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
|
||||
import grpc
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from vllm.grpc import vllm_engine_pb2, vllm_engine_pb2_grpc
|
||||
|
||||
# Use a small model for fast testing
|
||||
MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
|
||||
|
||||
|
||||
def find_free_port() -> int:
    """Return a TCP port that was free on the loopback interface.

    The kernel picks the port (bind to port 0). The socket is closed before
    returning, so the port is only *likely* to still be free when the caller
    binds it; another process may claim it in between.

    Returns:
        The port number chosen by the operating system.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        # Bind to loopback explicitly (not the wildcard address) so the
        # function does what its contract says. No listen() is needed just
        # to learn which port was assigned.
        sock.bind(("127.0.0.1", 0))
        return sock.getsockname()[1]
|
||||
|
||||
|
||||
async def wait_for_server(port: int, timeout: float = 60.0) -> bool:
    """Poll the gRPC HealthCheck RPC until the server reports healthy.

    Args:
        port: Port on ``localhost`` the server is expected to listen on.
        timeout: Maximum number of seconds to keep polling.

    Returns:
        ``True`` as soon as a HealthCheck response has ``healthy`` set,
        ``False`` if that did not happen before ``timeout`` elapsed.
    """
    # A monotonic clock keeps the deadline immune to wall-clock adjustments.
    deadline = time.monotonic() + timeout
    print("waiting for server to start...")
    while time.monotonic() < deadline:
        healthy = False
        try:
            # The context manager closes the channel even when the probe
            # raises, so failed attempts do not leak channels.
            async with grpc.aio.insecure_channel(f"localhost:{port}") as channel:
                stub = vllm_engine_pb2_grpc.VllmEngineStub(channel)
                response = await stub.HealthCheck(
                    vllm_engine_pb2.HealthCheckRequest(), timeout=5.0
                )
            healthy = response.healthy
        except Exception:
            # Best-effort probe: the server is most likely still starting.
            # Any failure simply means "not ready yet, try again".
            healthy = False
        if healthy:
            print("server returned healthy=True")
            return True
        # Back off after every unsuccessful attempt, including a response
        # with healthy=False, so the loop never spins.
        await asyncio.sleep(0.5)
    return False
|
||||
|
||||
|
||||
class GrpcServerProcess:
    """Manages a vLLM gRPC server running in a child process.

    ``process`` and ``port`` are ``None`` until :meth:`start` is called;
    ``process`` is reset to ``None`` again by :meth:`stop`.
    """

    def __init__(self):
        self.process: subprocess.Popen | None = None
        self.port: int | None = None

    async def start(self):
        """Launch the server on a free port and wait until it is healthy.

        Raises:
            RuntimeError: If the server did not become healthy before
                ``wait_for_server`` gave up. The child process is stopped
                before the error is raised.
        """
        self.port = find_free_port()

        # Run the server module with the current interpreter.
        self.process = subprocess.Popen(
            [
                sys.executable,
                "-m",
                "vllm.entrypoints.grpc_server",
                "--model",
                MODEL_NAME,
                "--host",
                "localhost",
                "--port",
                str(self.port),
                "--max-num-batched-tokens",
                "512",
                "--disable-log-stats-server",
            ],
        )

        try:
            ready = await wait_for_server(self.port)
        except BaseException:
            # Cancellation or an unexpected error while polling must not
            # leave an orphaned server process behind.
            self.stop()
            raise

        if not ready:
            self.stop()
            raise RuntimeError("gRPC server failed to start within timeout")

    def stop(self):
        """Terminate the server process; safe to call more than once.

        Sends a terminate request first and escalates to ``kill`` if the
        process has not exited within 10 seconds.
        """
        # Drop the handle first so repeated calls are no-ops.
        process, self.process = self.process, None
        if process is None:
            return

        process.terminate()
        try:
            process.wait(timeout=10)
        except subprocess.TimeoutExpired:
            process.kill()
            process.wait()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="module")
async def grpc_server():
    """Module-scoped fixture: start one gRPC server subprocess, stop it after."""
    server_process = GrpcServerProcess()
    await server_process.start()

    yield server_process

    # Teardown: runs once the tests in the module are done with the server.
    server_process.stop()
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
async def grpc_client(grpc_server):
    """Per-test fixture yielding a VllmEngine stub bound to the test server."""
    target = f"localhost:{grpc_server.port}"
    channel = grpc.aio.insecure_channel(target)

    yield vllm_engine_pb2_grpc.VllmEngineStub(channel)

    # Teardown: release the channel opened for this test.
    await channel.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_health_check(grpc_client):
    """HealthCheck reports a healthy engine with the expected message."""
    reply = await grpc_client.HealthCheck(vllm_engine_pb2.HealthCheckRequest())

    assert reply.healthy is True
    assert reply.message == "Health"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_model_info(grpc_client):
    """GetModelInfo describes the model the test server was started with."""
    info = await grpc_client.GetModelInfo(vllm_engine_pb2.GetModelInfoRequest())

    assert info.model_path == MODEL_NAME
    assert info.is_generation is True
    assert info.max_context_length > 0
    assert info.vocab_size > 0
    assert info.supports_vision is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_server_info(grpc_client):
    """GetServerInfo returns plausible runtime state for a running server."""
    info = await grpc_client.GetServerInfo(vllm_engine_pb2.GetServerInfoRequest())

    assert info.active_requests >= 0
    assert info.is_paused is False
    assert info.uptime_seconds >= 0
    assert info.server_type == "vllm-grpc"
    assert info.last_receive_timestamp > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_generate_non_streaming(grpc_client):
    """Non-streaming Generate yields exactly one ``complete`` response."""
    prompt = vllm_engine_pb2.TokenizedInput(
        original_text="Hello, my name is",
        input_ids=[15496, 11, 616, 1438, 318],  # GPT-2 tokens for the prompt
    )
    sampling = vllm_engine_pb2.SamplingParams(
        temperature=0.0,
        max_tokens=10,
        n=1,
    )
    request = vllm_engine_pb2.GenerateRequest(
        request_id="test-non-streaming-1",
        tokenized=prompt,
        sampling_params=sampling,
        stream=False,
    )

    # Drain the response stream.
    replies = [reply async for reply in grpc_client.Generate(request)]

    # With stream=False only the final message is expected.
    assert len(replies) == 1
    only_reply = replies[0]
    assert only_reply.HasField("complete")

    result = only_reply.complete
    assert len(result.output_ids) > 0
    assert result.finish_reason in ["stop", "length"]
    assert result.prompt_tokens > 0
    assert result.completion_tokens > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_generate_streaming(grpc_client):
    """Streaming Generate ends with a ``complete`` message after any chunks."""
    prompt = vllm_engine_pb2.TokenizedInput(
        original_text="The capital of France is",
        input_ids=[464, 3139, 286, 4881, 318],  # GPT-2 tokens
    )
    sampling = vllm_engine_pb2.SamplingParams(temperature=0.0, max_tokens=10, n=1)
    request = vllm_engine_pb2.GenerateRequest(
        request_id="test-streaming-1",
        tokenized=prompt,
        sampling_params=sampling,
        stream=True,
    )

    # Sort the stream into intermediate chunks and the final message.
    chunks = []
    complete_response = None
    async for reply in grpc_client.Generate(request):
        if reply.HasField("chunk"):
            chunks.append(reply.chunk)
        elif reply.HasField("complete"):
            complete_response = reply.complete

    # Kept from the original test: zero chunks is tolerated, so this check
    # cannot fail.
    assert len(chunks) >= 0  # May have 0 chunks if generation is very fast

    # The stream must terminate with a complete response.
    assert complete_response is not None
    assert complete_response.finish_reason in ["stop", "length"]
    assert complete_response.prompt_tokens > 0

    # Every chunk carries token accounting.
    for piece in chunks:
        assert piece.prompt_tokens > 0
        assert piece.completion_tokens >= 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_generate_with_different_sampling_params(grpc_client):
    """Generate completes under temperature/top_p and under top_k sampling."""
    # (request id, SamplingParams keyword arguments), issued in this order.
    cases = [
        ("test-sampling-temp", {"temperature": 0.8, "top_p": 0.95, "max_tokens": 5}),
        ("test-sampling-topk", {"temperature": 1.0, "top_k": 50, "max_tokens": 5}),
    ]

    for request_id, sampling_kwargs in cases:
        request = vllm_engine_pb2.GenerateRequest(
            request_id=request_id,
            tokenized=vllm_engine_pb2.TokenizedInput(
                original_text="Hello",
                input_ids=[15496],
            ),
            sampling_params=vllm_engine_pb2.SamplingParams(**sampling_kwargs),
            stream=False,
        )

        responses = [r async for r in grpc_client.Generate(request)]
        assert len(responses) == 1
        assert responses[0].HasField("complete")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_generate_with_stop_strings(grpc_client):
    """Generate accepts stop strings and still finishes with a valid reason."""
    sampling = vllm_engine_pb2.SamplingParams(
        temperature=0.0,
        max_tokens=20,
        stop=["\n", "END"],
    )
    request = vllm_engine_pb2.GenerateRequest(
        request_id="test-stop-strings",
        tokenized=vllm_engine_pb2.TokenizedInput(
            original_text="Hello",
            input_ids=[15496],
        ),
        sampling_params=sampling,
        stream=False,
    )

    replies = [reply async for reply in grpc_client.Generate(request)]
    assert len(replies) == 1

    final = replies[0]
    assert final.HasField("complete")
    assert final.complete.finish_reason in ["stop", "length"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_generate_multiple_requests(grpc_client):
    """Three concurrent non-streaming Generate requests all complete."""

    async def make_request(request_id: str):
        """Issue one non-streaming request and return its first response."""
        request = vllm_engine_pb2.GenerateRequest(
            request_id=request_id,
            tokenized=vllm_engine_pb2.TokenizedInput(
                original_text="Hello",
                input_ids=[15496],
            ),
            sampling_params=vllm_engine_pb2.SamplingParams(
                temperature=0.0, max_tokens=5
            ),
            stream=False,
        )

        responses = [r async for r in grpc_client.Generate(request)]
        return responses[0]

    # Send the requests concurrently over the shared client.
    tasks = [make_request(f"test-concurrent-{i}") for i in range(3)]
    responses = await asyncio.gather(*tasks)

    # Verify all requests completed successfully. The loop index was never
    # used, so iterate over the responses directly.
    assert len(responses) == 3
    for response in responses:
        assert response.HasField("complete")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_generate_with_seed(grpc_client):
    """Two sampled generations with the same seed produce identical tokens."""

    def make_request(request_id: str, seed: int):
        """Build a non-streaming request that samples with the given seed."""
        sampling = vllm_engine_pb2.SamplingParams(
            temperature=1.0, max_tokens=10, seed=seed
        )
        prompt = vllm_engine_pb2.TokenizedInput(
            original_text="The future of AI is",
            input_ids=[464, 2003, 286, 9552, 318],
        )
        return vllm_engine_pb2.GenerateRequest(
            request_id=request_id,
            tokenized=prompt,
            sampling_params=sampling,
            stream=False,
        )

    # Same seed, different request ids; run one after the other.
    first = [r async for r in grpc_client.Generate(make_request("test-seed-1", 42))]
    second = [r async for r in grpc_client.Generate(make_request("test-seed-2", 42))]

    # Each run must end in a single complete response.
    assert len(first) == 1
    assert len(second) == 1
    assert first[0].HasField("complete")
    assert second[0].HasField("complete")

    # Identical seeds must give identical output token ids.
    assert list(first[0].complete.output_ids) == list(second[0].complete.output_ids)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_generate_error_handling(grpc_client):
    """An out-of-range top_p is rejected with INVALID_ARGUMENT."""
    bad_sampling = vllm_engine_pb2.SamplingParams(
        temperature=0.0, max_tokens=10, top_p=-33
    )
    request = vllm_engine_pb2.GenerateRequest(
        request_id="test-error-invalid-topp",
        sampling_params=bad_sampling,
        stream=False,
    )

    # Draining the stream must surface the server-side rejection.
    with pytest.raises(grpc.RpcError) as exc_info:
        _ = [r async for r in grpc_client.Generate(request)]

    error = exc_info.value
    assert error.code() == grpc.StatusCode.INVALID_ARGUMENT
    assert "top_p must be in (0, 1], got -33.0" in error.details()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_abort_request(grpc_client):
    """The Abort RPC cuts short a generation that is still streaming."""
    request_id = "test-abort-1"

    # min_tokens == max_tokens == 500 keeps the request running long enough
    # for the abort to arrive mid-stream.
    generate_request = vllm_engine_pb2.GenerateRequest(
        request_id=request_id,
        tokenized=vllm_engine_pb2.TokenizedInput(
            original_text="Hello",
            input_ids=[15496],
        ),
        sampling_params=vllm_engine_pb2.SamplingParams(
            temperature=0.0,
            min_tokens=500,
            max_tokens=500,
        ),
        stream=True,
    )

    # Shared state updated by the consumer coroutine below.
    state = {"aborted": False, "chunks": 0}

    async def run_generate():
        async for response in grpc_client.Generate(generate_request):
            if response.HasField("chunk"):
                state["chunks"] += 1
            # Only the most recent message decides the verdict: it must be
            # a complete message whose finish reason is "abort".
            state["aborted"] = (
                response.HasField("complete")
                and response.complete.finish_reason == "abort"
            )

    async def abort_after_delay():
        # Small delay to ensure generate has started.
        await asyncio.sleep(0.1)
        await grpc_client.Abort(
            vllm_engine_pb2.AbortRequest(request_ids=[request_id])
        )

    # Consume the stream and fire the abort concurrently.
    await asyncio.gather(run_generate(), abort_after_delay())

    was_aborted = state["aborted"]
    received_chunks = state["chunks"]
    assert was_aborted and received_chunks < 500, (
        "Request should have been aborted before generating all 500 tokens"
    )
|
||||
@@ -51,6 +51,12 @@ class ServeSubcommand(CLISubcommand):
|
||||
if hasattr(args, "model_tag") and args.model_tag is not None:
|
||||
args.model = args.model_tag
|
||||
|
||||
if getattr(args, "grpc", False):
|
||||
from vllm.entrypoints.grpc_server import serve_grpc
|
||||
|
||||
uvloop.run(serve_grpc(args))
|
||||
return
|
||||
|
||||
if args.headless:
|
||||
if args.api_server_count is not None and args.api_server_count > 0:
|
||||
raise ValueError(
|
||||
@@ -127,6 +133,13 @@ class ServeSubcommand(CLISubcommand):
|
||||
)
|
||||
|
||||
serve_parser = make_arg_parser(serve_parser)
|
||||
serve_parser.add_argument(
|
||||
"--grpc",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Launch a gRPC server instead of the HTTP OpenAI-compatible "
|
||||
"server. Requires: pip install vllm[grpc].",
|
||||
)
|
||||
serve_parser.epilog = VLLM_SUBCMD_PARSER_EPILOG.format(subcmd=self.name)
|
||||
return serve_parser
|
||||
|
||||
|
||||
457
vllm/entrypoints/grpc_server.py
Executable file → Normal file
@@ -5,7 +5,8 @@
|
||||
"""
|
||||
vLLM gRPC Server
|
||||
|
||||
Starts a gRPC server for vLLM using the VllmEngine protocol.
|
||||
Starts a gRPC server backed by AsyncLLM, using the VllmEngineServicer
|
||||
from the smg-grpc-servicer package.
|
||||
|
||||
Usage:
|
||||
python -m vllm.entrypoints.grpc_server --model <model_path>
|
||||
@@ -22,19 +23,23 @@ import asyncio
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
import grpc
|
||||
try:
|
||||
import grpc
|
||||
from grpc_reflection.v1alpha import reflection
|
||||
from smg_grpc_proto import vllm_engine_pb2, vllm_engine_pb2_grpc
|
||||
from smg_grpc_servicer.vllm.servicer import VllmEngineServicer
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"smg-grpc-servicer is required for gRPC mode. "
|
||||
"Install it with: pip install vllm[grpc]"
|
||||
) from None
|
||||
|
||||
import uvloop
|
||||
from grpc_reflection.v1alpha import reflection
|
||||
|
||||
from vllm import SamplingParams, TextPrompt, TokensPrompt
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.entrypoints.utils import log_version_and_model
|
||||
from vllm.grpc import vllm_engine_pb2, vllm_engine_pb2_grpc
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import RequestOutputKind, StructuredOutputsParams
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
from vllm.v1.engine.async_llm import AsyncLLM
|
||||
@@ -43,377 +48,9 @@ from vllm.version import __version__ as VLLM_VERSION
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class VllmEngineServicer(vllm_engine_pb2_grpc.VllmEngineServicer):
|
||||
"""
|
||||
gRPC servicer implementing the VllmEngine service.
|
||||
|
||||
Handles 6 RPCs:
|
||||
- Generate: Streaming text generation
|
||||
- Embed: Embeddings (TODO)
|
||||
- HealthCheck: Health probe
|
||||
- Abort: Cancel requests out-of-band
|
||||
- GetModelInfo: Model metadata
|
||||
- GetServerInfo: Server state
|
||||
"""
|
||||
|
||||
def __init__(self, async_llm: AsyncLLM, start_time: float):
    """
    Store the engine handle and the server start time.

    Args:
        async_llm: The AsyncLLM instance that serves the RPCs
        start_time: The server start time, in seconds since epoch
    """
    self.start_time = start_time
    self.async_llm = async_llm
    logger.info("VllmEngineServicer initialized")
|
||||
|
||||
async def Generate(
    self,
    request: vllm_engine_pb2.GenerateRequest,
    context: grpc.aio.ServicerContext,
) -> AsyncGenerator[vllm_engine_pb2.GenerateResponse, None]:
    """
    Run a generation request and stream the results back.

    When ``request.stream`` is set, every engine output is forwarded as a
    chunk; whenever an output is marked finished, a complete response is
    sent as well (after the chunk, if any).

    Args:
        request: The GenerateRequest protobuf
        context: gRPC context, used to abort the RPC on errors

    Yields:
        GenerateResponse protobuf messages (chunk and/or complete)
    """
    request_id = request.request_id
    logger.debug("Generate request %s received.", request_id)

    try:
        # Build the engine prompt from whichever input the request carries.
        prompt: TokensPrompt | TextPrompt
        if request.WhichOneof("input") != "tokenized":
            prompt = {"prompt": request.text}
        else:
            tokenized = request.tokenized
            prompt = {"prompt_token_ids": list(tokenized.input_ids)}
            if tokenized.original_text:
                prompt["prompt"] = tokenized.original_text

        sampling_params = self._sampling_params_from_proto(
            request.sampling_params, stream=request.stream
        )
        tokenization_kwargs = self._tokenization_kwargs_from_proto(
            request.sampling_params
        )

        engine_outputs = self.async_llm.generate(
            prompt=prompt,
            sampling_params=sampling_params,
            request_id=request_id,
            tokenization_kwargs=tokenization_kwargs,
        )
        async for output in engine_outputs:
            if request.stream:
                yield self._chunk_response(output)
            if output.finished:
                yield self._complete_response(output)

    except ValueError as e:
        # Invalid request error (equiv to 400).
        await context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
    except Exception as e:
        logger.exception("Error in Generate for request %s", request_id)
        await context.abort(grpc.StatusCode.INTERNAL, str(e))
|
||||
|
||||
async def Embed(
    self,
    request: vllm_engine_pb2.EmbedRequest,
    context: grpc.aio.ServicerContext,
) -> vllm_engine_pb2.EmbedResponse:
    """
    Placeholder for the embedding RPC; always aborts as UNIMPLEMENTED.

    TODO: Implement in Phase 4

    Args:
        request: The EmbedRequest protobuf (currently unused)
        context: gRPC context, used to abort the RPC
    """
    status = grpc.StatusCode.UNIMPLEMENTED
    logger.warning("Embed RPC not yet implemented")
    await context.abort(status, "Embed RPC not yet implemented")
|
||||
|
||||
async def HealthCheck(
    self,
    request: vllm_engine_pb2.HealthCheckRequest,
    context: grpc.aio.ServicerContext,
) -> vllm_engine_pb2.HealthCheckResponse:
    """
    Report whether the engine is healthy.

    The engine counts as healthy as long as ``async_llm.errored`` is false.

    Args:
        request: The HealthCheckRequest protobuf (currently unused)
        context: gRPC context (currently unused)

    Returns:
        HealthCheckResponse with the health flag and a short message
    """
    if self.async_llm.errored:
        is_healthy, message = False, "Engine is not alive"
    else:
        is_healthy, message = True, "Health"

    logger.debug("HealthCheck request: healthy=%s, message=%s", is_healthy, message)

    return vllm_engine_pb2.HealthCheckResponse(healthy=is_healthy, message=message)
|
||||
|
||||
async def Abort(
    self,
    request: vllm_engine_pb2.AbortRequest,
    context: grpc.aio.ServicerContext,
) -> vllm_engine_pb2.AbortResponse:
    """
    Abort the listed requests out-of-band.

    Args:
        request: The AbortRequest protobuf carrying the request ids
        context: gRPC context (currently unused)

    Returns:
        An empty AbortResponse once the engine's abort call has returned
    """
    ids_to_abort = request.request_ids
    logger.debug("Abort requests: %s", ids_to_abort)

    await self.async_llm.abort(ids_to_abort)
    return vllm_engine_pb2.AbortResponse()
|
||||
|
||||
async def GetModelInfo(
    self,
    request: vllm_engine_pb2.GetModelInfoRequest,
    context: grpc.aio.ServicerContext,
) -> vllm_engine_pb2.GetModelInfoResponse:
    """
    Describe the loaded model using the engine's model config.

    Args:
        request: The GetModelInfoRequest protobuf (currently unused)
        context: gRPC context (currently unused)

    Returns:
        GetModelInfoResponse protobuf
    """
    cfg = self.async_llm.model_config
    is_generation = cfg.runner_type == "generate"

    return vllm_engine_pb2.GetModelInfoResponse(
        model_path=cfg.model,
        is_generation=is_generation,
        max_context_length=cfg.max_model_len,
        vocab_size=cfg.get_vocab_size(),
        supports_vision=cfg.is_multimodal_model,
    )
|
||||
|
||||
async def GetServerInfo(
    self,
    request: vllm_engine_pb2.GetServerInfoRequest,
    context: grpc.aio.ServicerContext,
) -> vllm_engine_pb2.GetServerInfoResponse:
    """
    Report the current server state.

    Args:
        request: The GetServerInfoRequest protobuf (currently unused)
        context: gRPC context (currently unused)

    Returns:
        GetServerInfoResponse protobuf
    """
    output_processor = self.async_llm.output_processor
    active = output_processor.get_num_unfinished_requests()

    return vllm_engine_pb2.GetServerInfoResponse(
        active_requests=active,
        # TODO: pause state is not tracked yet; always reported as False.
        is_paused=False,
        # TODO looks wrong? This is the time of this call, not of the last
        # received request.
        last_receive_timestamp=time.time(),
        uptime_seconds=time.time() - self.start_time,
        server_type="vllm-grpc",
    )
|
||||
|
||||
# ========== Helper methods ==========
|
||||
|
||||
@staticmethod
def _sampling_params_from_proto(
    params: vllm_engine_pb2.SamplingParams, stream: bool = True
) -> SamplingParams:
    """
    Convert protobuf SamplingParams to vLLM SamplingParams.

    Unset or zero-valued proto fields are mapped onto defaults: temperature
    1.0, top_p 1.0, repetition_penalty 1.0, n 1, and ``None`` for
    max_tokens, logprobs, prompt_logprobs and seed.

    Args:
        params: Protobuf SamplingParams message
        stream: Whether streaming is enabled; selects DELTA output when
            true and FINAL_ONLY otherwise

    Returns:
        vLLM SamplingParams. ``detokenize`` is enabled only when stop
        strings are present; ``structured_outputs`` is set when the
        message carries one of the known constraints.
    """

    def optional(field: str):
        """Return the field's value if it is explicitly set, else None."""
        return getattr(params, field) if params.HasField(field) else None

    # Empty repeated fields become None.
    stop = list(params.stop) or None
    stop_token_ids = list(params.stop_token_ids) or None

    # Proto constraint member -> StructuredOutputsParams keyword. "choice"
    # is handled separately because its value is a nested message.
    constraint_kwarg = {
        "json_schema": "json",
        "regex": "regex",
        "grammar": "grammar",
        "structural_tag": "structural_tag",
        "json_object": "json_object",
    }
    constraint_field = params.WhichOneof("constraint")
    if constraint_field == "choice":
        structured_outputs = StructuredOutputsParams(
            choice=list(params.choice.choices)
        )
    elif constraint_field in constraint_kwarg:
        structured_outputs = StructuredOutputsParams(
            **{constraint_kwarg[constraint_field]: getattr(params, constraint_field)}
        )
    else:
        # No constraint set, or one this servicer does not know about.
        structured_outputs = None

    temperature = params.temperature if params.HasField("temperature") else 1.0
    # A zero value means "not provided" for these two fields.
    top_p = params.top_p if params.top_p != 0.0 else 1.0
    repetition_penalty = (
        params.repetition_penalty if params.repetition_penalty != 0.0 else 1.0
    )
    # DELTA returns only the new tokens in each chunk (for streaming).
    output_kind = (
        RequestOutputKind.DELTA if stream else RequestOutputKind.FINAL_ONLY
    )

    return SamplingParams(
        temperature=temperature,
        top_p=top_p,
        top_k=params.top_k,
        min_p=params.min_p,
        frequency_penalty=params.frequency_penalty,
        presence_penalty=params.presence_penalty,
        repetition_penalty=repetition_penalty,
        max_tokens=optional("max_tokens"),
        min_tokens=params.min_tokens,
        stop=stop,
        stop_token_ids=stop_token_ids,
        skip_special_tokens=params.skip_special_tokens,
        spaces_between_special_tokens=params.spaces_between_special_tokens,
        ignore_eos=params.ignore_eos,
        n=params.n if params.n > 0 else 1,
        logprobs=optional("logprobs"),
        prompt_logprobs=optional("prompt_logprobs"),
        seed=optional("seed"),
        include_stop_str_in_output=params.include_stop_str_in_output,
        logit_bias=dict(params.logit_bias) if params.logit_bias else None,
        structured_outputs=structured_outputs,
        # detokenize must be True if stop strings are used
        detokenize=bool(stop),
        output_kind=output_kind,
    )
|
||||
|
||||
@staticmethod
def _tokenization_kwargs_from_proto(
    params: vllm_engine_pb2.SamplingParams,
) -> dict[str, int] | None:
    """Extract tokenization keyword arguments from protobuf sampling params.

    Args:
        params: Protobuf SamplingParams message.

    Returns:
        A one-entry dict carrying ``truncate_prompt_tokens`` when that
        optional field was explicitly set on the message, otherwise ``None``.
    """
    field_name = "truncate_prompt_tokens"
    # Presence check (not truthiness): the field is declared `optional`, so an
    # explicitly supplied value must be told apart from an unset one.
    if not params.HasField(field_name):
        return None
    return {field_name: getattr(params, field_name)}
|
||||
|
||||
@staticmethod
def _chunk_response(output: RequestOutput) -> vllm_engine_pb2.GenerateResponse:
    """Wrap one streamed vLLM output in a ``GenerateResponse`` chunk.

    With ``output_kind=DELTA`` vLLM returns only the newly produced tokens, so
    both the token list and ``completion_tokens`` emitted here are per-chunk
    deltas; the client is expected to accumulate them.

    Args:
        output: vLLM RequestOutput (delta tokens when output_kind=DELTA).

    Returns:
        GenerateResponse with the ``chunk`` field set. If the output carries
        no completions, the chunk is empty and every counter is zero.
    """
    candidates = output.outputs
    # Only the first completion is reported, even when n > 1.
    delta = candidates[0] if candidates else None

    if delta is not None:
        new_token_ids = delta.token_ids
        prompt_ids = output.prompt_token_ids
        chunk = vllm_engine_pb2.GenerateStreamChunk(
            token_ids=new_token_ids,
            prompt_tokens=len(prompt_ids) if prompt_ids else 0,
            # Delta count for this chunk, not a running total.
            completion_tokens=len(new_token_ids),
            cached_tokens=output.num_cached_tokens,
        )
    else:
        # Nothing was produced: emit an empty chunk.
        chunk = vllm_engine_pb2.GenerateStreamChunk(
            token_ids=[],
            prompt_tokens=0,
            completion_tokens=0,
            cached_tokens=0,
        )

    return vllm_engine_pb2.GenerateResponse(chunk=chunk)
|
||||
|
||||
@staticmethod
def _complete_response(output: RequestOutput) -> vllm_engine_pb2.GenerateResponse:
    """Wrap a finished vLLM output in a final ``GenerateResponse``.

    In streaming (DELTA) mode the completion's token list is empty or holds
    only the last delta; in non-streaming (FINAL_ONLY) mode it holds every
    generated token. Streaming clients accumulate the counts themselves.

    Args:
        output: vLLM RequestOutput (finished=True).

    Returns:
        GenerateResponse with the ``complete`` field set. If the output
        carries no completions, the result is empty with
        ``finish_reason="error"`` and zeroed counters.
    """
    candidates = output.outputs
    # Only the first completion is reported, even when n > 1.
    final = candidates[0] if candidates else None

    if final is not None:
        generated_ids = final.token_ids
        prompt_ids = output.prompt_token_ids
        # Fall back to "stop" when the engine reports no finish reason.
        reason = final.finish_reason
        if not reason:
            reason = "stop"
        summary = vllm_engine_pb2.GenerateComplete(
            output_ids=generated_ids,
            finish_reason=reason,
            prompt_tokens=len(prompt_ids) if prompt_ids else 0,
            completion_tokens=len(generated_ids),
            cached_tokens=output.num_cached_tokens,
        )
    else:
        # No completion at all is reported to the client as an error.
        summary = vllm_engine_pb2.GenerateComplete(
            output_ids=[],
            finish_reason="error",
            prompt_tokens=0,
            completion_tokens=0,
            cached_tokens=0,
        )

    return vllm_engine_pb2.GenerateResponse(complete=summary)
|
||||
|
||||
|
||||
async def serve_grpc(args: argparse.Namespace):
|
||||
"""
|
||||
Main serving function.
|
||||
Main gRPC serving function.
|
||||
|
||||
Args:
|
||||
args: Parsed command line arguments
|
||||
@@ -428,7 +65,7 @@ async def serve_grpc(args: argparse.Namespace):
|
||||
|
||||
# Build vLLM config
|
||||
vllm_config = engine_args.create_engine_config(
|
||||
usage_context=UsageContext.OPENAI_API_SERVER
|
||||
usage_context=UsageContext.OPENAI_API_SERVER,
|
||||
)
|
||||
|
||||
# Create AsyncLLM
|
||||
@@ -436,7 +73,7 @@ async def serve_grpc(args: argparse.Namespace):
|
||||
vllm_config=vllm_config,
|
||||
usage_context=UsageContext.OPENAI_API_SERVER,
|
||||
enable_log_requests=args.enable_log_requests,
|
||||
disable_log_stats=args.disable_log_stats_server,
|
||||
disable_log_stats=args.disable_log_stats,
|
||||
)
|
||||
|
||||
# Create servicer
|
||||
@@ -447,6 +84,11 @@ async def serve_grpc(args: argparse.Namespace):
|
||||
options=[
|
||||
("grpc.max_send_message_length", -1),
|
||||
("grpc.max_receive_message_length", -1),
|
||||
# Tolerate client keepalive pings every 10s (default 300s is too
|
||||
# strict for non-streaming requests where no DATA frames flow
|
||||
# during generation)
|
||||
("grpc.http2.min_recv_ping_interval_without_data_ms", 10000),
|
||||
("grpc.keepalive_permit_without_calls", True),
|
||||
],
|
||||
)
|
||||
|
||||
@@ -461,46 +103,42 @@ async def serve_grpc(args: argparse.Namespace):
|
||||
reflection.enable_server_reflection(service_names, server)
|
||||
|
||||
# Bind to address
|
||||
address = f"{args.host}:{args.port}"
|
||||
host = args.host or "0.0.0.0"
|
||||
address = f"{host}:{args.port}"
|
||||
server.add_insecure_port(address)
|
||||
|
||||
# Start server
|
||||
await server.start()
|
||||
logger.info("vLLM gRPC server started on %s", address)
|
||||
logger.info("Server is ready to accept requests")
|
||||
|
||||
# Handle shutdown signals
|
||||
loop = asyncio.get_running_loop()
|
||||
stop_event = asyncio.Event()
|
||||
|
||||
def signal_handler():
|
||||
logger.info("Received shutdown signal")
|
||||
stop_event.set()
|
||||
|
||||
for sig in (signal.SIGTERM, signal.SIGINT):
|
||||
loop.add_signal_handler(sig, signal_handler)
|
||||
|
||||
# Serve until shutdown signal
|
||||
try:
|
||||
await stop_event.wait()
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Interrupted by user")
|
||||
# Start server
|
||||
await server.start()
|
||||
logger.info("vLLM gRPC server started on %s", address)
|
||||
logger.info("Server is ready to accept requests")
|
||||
|
||||
# Handle shutdown signals
|
||||
loop = asyncio.get_running_loop()
|
||||
stop_event = asyncio.Event()
|
||||
|
||||
def signal_handler():
|
||||
logger.info("Received shutdown signal")
|
||||
stop_event.set()
|
||||
|
||||
for sig in (signal.SIGTERM, signal.SIGINT):
|
||||
loop.add_signal_handler(sig, signal_handler)
|
||||
|
||||
try:
|
||||
await stop_event.wait()
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Interrupted by user")
|
||||
finally:
|
||||
logger.info("Shutting down vLLM gRPC server...")
|
||||
|
||||
# Stop gRPC server
|
||||
await server.stop(grace=5.0)
|
||||
logger.info("gRPC server stopped")
|
||||
|
||||
# Shutdown AsyncLLM
|
||||
async_llm.shutdown()
|
||||
logger.info("AsyncLLM engine stopped")
|
||||
|
||||
logger.info("Shutdown complete")
|
||||
|
||||
|
||||
def main():
|
||||
"""Main entry point."""
|
||||
"""Main entry point for python -m vllm.entrypoints.grpc_server."""
|
||||
parser = FlexibleArgumentParser(
|
||||
description="vLLM gRPC Server",
|
||||
)
|
||||
@@ -518,13 +156,6 @@ def main():
|
||||
default=50051,
|
||||
help="Port to bind gRPC server to",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--disable-log-stats-server",
|
||||
action="store_true",
|
||||
help="Disable stats logging on server side",
|
||||
)
|
||||
|
||||
# Add vLLM engine args
|
||||
parser = AsyncEngineArgs.add_cli_args(parser)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -1,17 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
vLLM gRPC protocol definitions.
|
||||
|
||||
This module contains the protocol buffer definitions for vLLM's gRPC API.
|
||||
The protobuf files are compiled into Python code using grpcio-tools.
|
||||
"""
|
||||
|
||||
# These imports will be available after protobuf compilation
|
||||
# from vllm.grpc import vllm_engine_pb2
|
||||
# from vllm.grpc import vllm_engine_pb2_grpc
|
||||
|
||||
__all__ = [
|
||||
"vllm_engine_pb2",
|
||||
"vllm_engine_pb2_grpc",
|
||||
]
|
||||
@@ -1,94 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Compile vLLM protobuf definitions into Python code.
|
||||
|
||||
This script uses grpcio-tools to generate *_pb2.py, *_pb2_grpc.py, and
|
||||
*_pb2.pyi (type stubs) files from the vllm_engine.proto definition.
|
||||
|
||||
NOTE: Proto compilation happens automatically during package build (via setup.py).
|
||||
This script is provided for developers who want to regenerate protos manually,
|
||||
e.g., after modifying vllm_engine.proto.
|
||||
|
||||
Usage:
|
||||
python vllm/grpc/compile_protos.py
|
||||
|
||||
Requirements:
|
||||
pip install grpcio-tools
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def compile_protos():
    """Compile ``vllm_engine.proto`` into Python modules and type stubs.

    Runs ``grpc_tools.protoc`` on the ``vllm_engine.proto`` file located next
    to this script, then prepends SPDX and ``# mypy: ignore-errors`` header
    lines to every generated file that does not already start with an SPDX
    identifier.

    Progress messages go to stdout; error diagnostics go to stderr so they
    are not lost when stdout is redirected or parsed.

    Returns:
        int: 0 on success; 1 if the proto file is missing, grpcio-tools is
        not installed, or an unexpected error occurs; otherwise the non-zero
        status returned by protoc.
    """
    # Get the vllm package root directory
    script_dir = Path(__file__).parent
    vllm_package_root = script_dir.parent.parent  # vllm/vllm/grpc -> vllm/

    proto_file = script_dir / "vllm_engine.proto"

    if not proto_file.exists():
        print(f"Error: Proto file not found at {proto_file}", file=sys.stderr)
        return 1

    print(f"Compiling protobuf: {proto_file}")
    print(f"Output directory: {script_dir}")

    # Imported lazily so that a missing grpcio-tools produces an actionable
    # message instead of an import-time traceback.
    try:
        from grpc_tools import protoc
    except ImportError:
        print("Error: grpcio-tools not installed", file=sys.stderr)
        print("Install with: pip install grpcio-tools", file=sys.stderr)
        return 1

    generated_files = [
        script_dir / "vllm_engine_pb2.py",
        script_dir / "vllm_engine_pb2_grpc.py",
        script_dir / "vllm_engine_pb2.pyi",
    ]

    try:
        # Compile the proto file
        # We use vllm/vllm as the proto_path so that the package is vllm.grpc.engine
        result = protoc.main(
            [
                "grpc_tools.protoc",
                f"--proto_path={vllm_package_root}",
                f"--python_out={vllm_package_root}",
                f"--grpc_python_out={vllm_package_root}",
                f"--pyi_out={vllm_package_root}",  # Generate type stubs
                str(proto_file),
            ]
        )

        if result != 0:
            print(f"Error: protoc returned {result}", file=sys.stderr)
            return result

        # Add SPDX headers (plus a mypy ignore-errors marker) to generated files
        header = (
            "# SPDX-License-Identifier: Apache-2.0\n"
            "# SPDX-FileCopyrightText: Copyright contributors to the vLLM project\n"
            "# mypy: ignore-errors\n"
        )
        for generated_file in generated_files:
            if not generated_file.exists():
                continue
            # Explicit UTF-8 keeps the rewrite independent of the host locale.
            content = generated_file.read_text(encoding="utf-8")
            # Skip files that already carry the header so reruns are idempotent.
            if not content.startswith("# SPDX-License-Identifier"):
                generated_file.write_text(header + content, encoding="utf-8")

        print("✓ Protobuf compilation successful!")
        print(f"  Generated: {script_dir / 'vllm_engine_pb2.py'}")
        print(f"  Generated: {script_dir / 'vllm_engine_pb2_grpc.py'}")
        print(f"  Generated: {script_dir / 'vllm_engine_pb2.pyi'} (type stubs)")
        return 0

    except Exception as e:
        # Top-level CLI boundary: report any failure and signal it through
        # the return code rather than a traceback.
        print(f"Error during compilation: {e}", file=sys.stderr)
        return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Use compile_protos()'s return code as the process exit status.
    raise SystemExit(compile_protos())
|
||||
@@ -1,195 +0,0 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package vllm.grpc.engine;
|
||||
|
||||
// Service definition for vLLM engine communication
|
||||
// This protocol is designed for efficient binary communication between
|
||||
// the Rust router and vLLM Python engine (AsyncLLM).
|
||||
service VllmEngine {
|
||||
// Submit a generation request (supports streaming)
// Server-streaming RPC: the reply is a stream of GenerateResponse messages,
// each holding either a `chunk` or a `complete` payload.
|
||||
rpc Generate(GenerateRequest) returns (stream GenerateResponse);
|
||||
|
||||
// Submit an embedding request
// Unary RPC; EmbedRequest accepts pre-tokenized input only.
|
||||
rpc Embed(EmbedRequest) returns (EmbedResponse);
|
||||
|
||||
// Health check
|
||||
rpc HealthCheck(HealthCheckRequest) returns (HealthCheckResponse);
|
||||
|
||||
// Abort a running request
// AbortRequest carries a list of IDs, so several requests can be named per call.
|
||||
rpc Abort(AbortRequest) returns (AbortResponse);
|
||||
|
||||
// Get model information
|
||||
rpc GetModelInfo(GetModelInfoRequest) returns (GetModelInfoResponse);
|
||||
|
||||
// Get server information
|
||||
rpc GetServerInfo(GetServerInfoRequest) returns (GetServerInfoResponse);
|
||||
}
|
||||
|
||||
// =====================
|
||||
// Common Types
|
||||
// =====================
|
||||
|
||||
// Sampling parameters for text generation
|
||||
message SamplingParams {
|
||||
// `optional` gives this field explicit presence; the Python servicer uses
// 1.0 when it is not set.
optional float temperature = 1;
|
||||
// No presence tracking: the servicer treats 0.0 as "unset" and uses 1.0.
float top_p = 2;
|
||||
uint32 top_k = 3;
|
||||
float min_p = 4;
|
||||
float frequency_penalty = 5;
|
||||
float presence_penalty = 6;
|
||||
// No presence tracking: the servicer treats 0.0 as "unset" and uses 1.0.
float repetition_penalty = 7;
|
||||
|
||||
// When not set, the servicer passes max_tokens=None to vLLM.
optional uint32 max_tokens = 8;
|
||||
uint32 min_tokens = 9;
|
||||
|
||||
// The servicer enables detokenization only when stop strings are supplied.
repeated string stop = 10;
|
||||
repeated uint32 stop_token_ids = 11;
|
||||
|
||||
bool skip_special_tokens = 12;
|
||||
bool spaces_between_special_tokens = 13;
|
||||
bool ignore_eos = 14;
|
||||
|
||||
// A value of 0 (the proto3 default) is mapped to 1 by the servicer.
uint32 n = 15; // Number of parallel samples
|
||||
|
||||
// Logprobs configuration
// Both fields are `optional`; when not set the servicer passes None.
|
||||
optional int32 logprobs = 22; // Number of log probabilities per output token (-1 for all)
|
||||
optional int32 prompt_logprobs = 23; // Number of log probabilities per prompt token (-1 for all)
|
||||
|
||||
// Additional vLLM fields
|
||||
optional int32 seed = 24; // Random seed for reproducibility
|
||||
bool include_stop_str_in_output = 25; // Whether to include stop strings in output
|
||||
// An empty map is forwarded to vLLM as None by the servicer.
map<int32, float> logit_bias = 26; // Token ID to bias mapping (-100 to 100)
|
||||
// Forwarded as a tokenization kwarg only when explicitly set.
optional int32 truncate_prompt_tokens = 27; // Prompt truncation (-1 for model max)
|
||||
|
||||
// Structured outputs (one of) - matches vLLM's StructuredOutputsParams
// At most one member of the oneof can be set at a time.
|
||||
oneof constraint {
|
||||
string json_schema = 16; // JSON schema for structured output
|
||||
string regex = 17; // Regex pattern
|
||||
string grammar = 18; // Grammar/EBNF for structured output
|
||||
string structural_tag = 19; // Structural tag (e.g., Harmony models)
|
||||
bool json_object = 20; // Force JSON object output
|
||||
ChoiceConstraint choice = 21; // List of allowed choices
|
||||
}
|
||||
}
|
||||
|
||||
// Choice constraint for structured outputs
|
||||
message ChoiceConstraint {
|
||||
// Wrapped in its own message because a `repeated` field cannot appear
// directly inside a oneof. The servicer converts this to a Python list for
// StructuredOutputsParams(choice=...).
repeated string choices = 1;
|
||||
}
|
||||
|
||||
// Pre-tokenized input from Rust router
|
||||
message TokenizedInput {
|
||||
// Not declared `optional`, so an empty string is indistinguishable from unset.
string original_text = 1; // For reference/debugging
|
||||
repeated uint32 input_ids = 2; // Actual token IDs to process
|
||||
}
|
||||
|
||||
// =====================
|
||||
// Generate Request
|
||||
// =====================
|
||||
|
||||
// Request message for the Generate RPC.
message GenerateRequest {
|
||||
string request_id = 1;
|
||||
|
||||
// Prompt input
// Either pre-tokenized IDs or raw text; a oneof holds at most one of them.
|
||||
oneof input {
|
||||
TokenizedInput tokenized = 2;
|
||||
string text = 3;
|
||||
}
|
||||
|
||||
// Generation parameters (includes logprobs config)
|
||||
SamplingParams sampling_params = 4;
|
||||
|
||||
// Streaming
// NOTE(review): presumably selects DELTA (true) versus FINAL_ONLY (false)
// output in the servicer - confirm against the Generate handler.
|
||||
bool stream = 5;
|
||||
}
|
||||
|
||||
// =====================
|
||||
// Generate Response
|
||||
// =====================
|
||||
|
||||
// Message streamed back by the Generate RPC; carries one of the payloads below.
message GenerateResponse {
|
||||
oneof response {
|
||||
GenerateStreamChunk chunk = 1; // For streaming
|
||||
GenerateComplete complete = 2; // For final/non-streaming
|
||||
}
|
||||
}
|
||||
|
||||
// Incremental payload of a streaming Generate call.
message GenerateStreamChunk {
|
||||
repeated uint32 token_ids = 1; // Incremental tokens
|
||||
// Length of the prompt token list; 0 when the servicer has no prompt IDs.
uint32 prompt_tokens = 2;
|
||||
// Number of tokens in THIS chunk only (a delta); clients accumulate the total.
uint32 completion_tokens = 3;
|
||||
uint32 cached_tokens = 4;
|
||||
|
||||
// Logprobs support (TODO: implement in Phase 4)
// Field numbers 5 and 6 are earmarked by the commented-out lines below.
|
||||
// OutputLogProbs output_logprobs = 5;
|
||||
// InputLogProbs input_logprobs = 6; // Only in first chunk
|
||||
}
|
||||
|
||||
// Final payload of a Generate call (streaming or non-streaming).
message GenerateComplete {
|
||||
// In streaming (DELTA) mode the servicer fills this with the last delta
// only (possibly empty); it holds every token only in non-streaming mode.
repeated uint32 output_ids = 1; // All output tokens
|
||||
// The servicer also emits "error" when the engine output has no
// completions, and falls back to "stop" when no reason is reported.
string finish_reason = 2; // "stop", "length", "abort"
|
||||
uint32 prompt_tokens = 3;
|
||||
// Equals the length of output_ids as sent in this message.
uint32 completion_tokens = 4;
|
||||
uint32 cached_tokens = 5;
|
||||
|
||||
// Logprobs support (TODO: implement in Phase 4)
// Field numbers 6 and 7 are earmarked by the commented-out lines below.
|
||||
// OutputLogProbs output_logprobs = 6;
|
||||
// InputLogProbs input_logprobs = 7;
|
||||
}
|
||||
|
||||
// =====================
|
||||
// Embedding Request
|
||||
// =====================
|
||||
|
||||
// Request message for the Embed RPC. Unlike GenerateRequest there is no
// raw-text alternative: input must be pre-tokenized.
message EmbedRequest {
|
||||
string request_id = 1;
|
||||
TokenizedInput tokenized = 2;
|
||||
}
|
||||
|
||||
// Response message for the Embed RPC.
message EmbedResponse {
|
||||
repeated float embedding = 1;
|
||||
uint32 prompt_tokens = 2;
|
||||
// NOTE(review): presumably the length of `embedding` - confirm with the servicer.
uint32 embedding_dim = 3;
|
||||
}
|
||||
|
||||
// =====================
|
||||
// Management Operations
|
||||
// =====================
|
||||
|
||||
// Empty request for the HealthCheck RPC (no parameters).
message HealthCheckRequest {}
|
||||
|
||||
// Reply of the HealthCheck RPC.
message HealthCheckResponse {
|
||||
bool healthy = 1;
|
||||
// NOTE(review): presumably a human-readable status detail - confirm with the servicer.
string message = 2;
|
||||
}
|
||||
|
||||
// Request message for the Abort RPC.
message AbortRequest {
|
||||
// Repeated, so one call can name several requests.
repeated string request_ids = 1;
|
||||
}
|
||||
|
||||
// Reply of the Abort RPC; intentionally carries no fields.
message AbortResponse {
|
||||
}
|
||||
|
||||
// =====================
|
||||
// Model and Server Info
|
||||
// =====================
|
||||
|
||||
// Empty request for the GetModelInfo RPC (no parameters).
message GetModelInfoRequest {}
|
||||
|
||||
// Reply of the GetModelInfo RPC.
// NOTE(review): field meanings below are inferred from their names only;
// confirm against the servicer that fills them.
message GetModelInfoResponse {
|
||||
string model_path = 1;
|
||||
bool is_generation = 2;
|
||||
uint32 max_context_length = 3;
|
||||
uint32 vocab_size = 4;
|
||||
bool supports_vision = 5;
|
||||
}
|
||||
|
||||
// Empty request for the GetServerInfo RPC (no parameters).
message GetServerInfoRequest {}
|
||||
|
||||
// Reply of the GetServerInfo RPC.
message GetServerInfoResponse {
|
||||
uint32 active_requests = 1;
|
||||
bool is_paused = 2;
|
||||
// NOTE(review): epoch and unit are not visible here - presumably seconds; confirm.
double last_receive_timestamp = 3;
|
||||
double uptime_seconds = 4;
|
||||
string server_type = 5; // "vllm-grpc"
|
||||
}
|
||||
Reference in New Issue
Block a user