diff --git a/pyproject.toml b/pyproject.toml index ad2a96db3..07d46f0ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,6 @@ requires = [ "torch == 2.10.0", "wheel", "jinja2", - "grpcio-tools==1.78.0", ] build-backend = "setuptools.build_meta" @@ -57,10 +56,6 @@ include = ["vllm*"] "vllm/third_party/**" = ["ALL"] "vllm/version.py" = ["F401"] "vllm/_version.py" = ["ALL"] -# Exclude generated protobuf files -"vllm/grpc/*_pb2.py" = ["ALL"] -"vllm/grpc/*_pb2_grpc.py" = ["ALL"] -"vllm/grpc/*_pb2.pyi" = ["ALL"] [tool.ruff.lint] select = [ diff --git a/requirements/build.txt b/requirements/build.txt index 6c6c9fc8a..c46880a05 100644 --- a/requirements/build.txt +++ b/requirements/build.txt @@ -10,4 +10,3 @@ jinja2>=3.1.6 regex build protobuf >= 5.29.6, !=6.30.*, !=6.31.*, !=6.32.*, !=6.33.0.*, !=6.33.1.*, !=6.33.2.*, !=6.33.3.*, !=6.33.4.* -grpcio-tools==1.78.0 # Required for grpc entrypoints diff --git a/requirements/common.txt b/requirements/common.txt index b9ea8cd2c..5e156edb7 100644 --- a/requirements/common.txt +++ b/requirements/common.txt @@ -51,8 +51,6 @@ openai-harmony >= 0.0.3 # Required for gpt-oss anthropic >= 0.71.0 model-hosting-container-standards >= 0.1.13, < 1.0.0 mcp -grpcio -grpcio-reflection opentelemetry-sdk >= 1.27.0 opentelemetry-api >= 1.27.0 opentelemetry-exporter-otlp >= 1.27.0 diff --git a/requirements/rocm.txt b/requirements/rocm.txt index a46a1b574..d70083338 100644 --- a/requirements/rocm.txt +++ b/requirements/rocm.txt @@ -4,7 +4,6 @@ # The version of gRPC libraries should be consistent with each other grpcio==1.78.0 grpcio-reflection==1.78.0 -grpcio-tools==1.78.0 numba == 0.61.2 # Required for N-gram speculative decoding diff --git a/requirements/test.in b/requirements/test.in index a551a4c05..85c477c02 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -51,7 +51,6 @@ tritonclient>=2.51.0 # The version of gRPC libraries should be consistent with each other grpcio==1.78.0 grpcio-reflection==1.78.0 -grpcio-tools==1.78.0 arctic-inference == 0.1.1 # Required for suffix decoding test numba == 0.61.2 # Required for N-gram speculative decoding diff --git a/requirements/test.txt b/requirements/test.txt index aacb8fbff..167abb530 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -289,13 +289,10 @@ grpcio==1.78.0 # via # -r requirements/test.in # grpcio-reflection - # grpcio-tools # ray # tensorboard grpcio-reflection==1.78.0 # via -r requirements/test.in -grpcio-tools==1.78.0 - # via -r requirements/test.in h11==0.14.0 # via # httpcore @@ -765,7 +762,6 @@ protobuf==6.33.2 # google-api-core # googleapis-common-protos # grpcio-reflection - # grpcio-tools # opentelemetry-proto # proto-plus # ray @@ -1045,7 +1041,6 @@ sentry-sdk==2.52.0 # via wandb setuptools==77.0.3 # via - # grpcio-tools # lightning-utilities # pytablewriter # tensorboard diff --git a/setup.py b/setup.py index f31b4cf24..691234b3a 100644 --- a/setup.py +++ b/setup.py @@ -18,8 +18,6 @@ import torch from packaging.version import Version, parse from setuptools import Extension, setup from setuptools.command.build_ext import build_ext -from setuptools.command.build_py import build_py -from setuptools.command.develop import develop from setuptools_scm import get_version from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME @@ -81,81 +79,6 @@ def is_freethreaded(): return bool(sysconfig.get_config_var("Py_GIL_DISABLED")) -def compile_grpc_protos(): - """Compile gRPC protobuf definitions during build. - - This generates *_pb2.py, *_pb2_grpc.py, and *_pb2.pyi files from - the vllm_engine.proto definition. - """ - try: - from grpc_tools import protoc - except ImportError: - logger.warning( - "grpcio-tools not installed, skipping gRPC proto compilation. " - "gRPC server functionality will not be available." - ) - return False - - proto_file = ROOT_DIR / "vllm" / "grpc" / "vllm_engine.proto" - if not proto_file.exists(): - logger.warning("Proto file not found at %s, skipping compilation", proto_file) - return False - - logger.info("Compiling gRPC protobuf: %s", proto_file) - - result = protoc.main( - [ - "grpc_tools.protoc", - f"--proto_path={ROOT_DIR}", - f"--python_out={ROOT_DIR}", - f"--grpc_python_out={ROOT_DIR}", - f"--pyi_out={ROOT_DIR}", - str(proto_file), - ] - ) - - if result != 0: - logger.error("protoc failed with exit code %s", result) - return False - - # Add SPDX headers and mypy ignore to generated files - spdx_header = ( - "# SPDX-License-Identifier: Apache-2.0\n" - "# SPDX-FileCopyrightText: Copyright contributors to the vLLM project\n" - "# mypy: ignore-errors\n" - ) - - grpc_dir = ROOT_DIR / "vllm" / "grpc" - for generated_file in [ - grpc_dir / "vllm_engine_pb2.py", - grpc_dir / "vllm_engine_pb2_grpc.py", - grpc_dir / "vllm_engine_pb2.pyi", - ]: - if generated_file.exists(): - content = generated_file.read_text() - if not content.startswith("# SPDX-License-Identifier"): - generated_file.write_text(spdx_header + content) - - logger.info("gRPC protobuf compilation successful") - return True - - -class BuildPyAndGenerateGrpc(build_py): - """Build Python modules and generate gRPC stubs from proto files.""" - - def run(self): - compile_grpc_protos() - super().run() - - -class DevelopAndGenerateGrpc(develop): - """Develop mode that also generates gRPC stubs from proto files.""" - - def run(self): - compile_grpc_protos() - super().run() - - class CMakeExtension(Extension): def __init__(self, name: str, cmake_lists_dir: str = ".", **kwa) -> None: super().__init__(name, sources=[], py_limited_api=not is_freethreaded(), **kwa) @@ -1028,17 +951,12 @@ if _no_device(): ext_modules = [] if not ext_modules: - cmdclass = { - "build_py": BuildPyAndGenerateGrpc, - "develop": DevelopAndGenerateGrpc, - } + cmdclass = {} else: cmdclass = { "build_ext": precompiled_build_ext if envs.VLLM_USE_PRECOMPILED else cmake_build_ext, - "build_py": BuildPyAndGenerateGrpc, - "develop": DevelopAndGenerateGrpc, } setup( @@ -1064,6 +982,8 @@ setup( "petit-kernel": ["petit-kernel"], # Optional deps for Helion kernel development "helion": ["helion"], + # Optional deps for gRPC server (vllm serve --grpc) + "grpc": ["smg-grpc-servicer >= 0.4.2"], # Optional deps for OpenTelemetry tracing "otel": [ "opentelemetry-sdk>=1.26.0", diff --git a/tests/entrypoints/test_grpc_server.py b/tests/entrypoints/test_grpc_server.py deleted file mode 100644 index a4e3a3860..000000000 --- a/tests/entrypoints/test_grpc_server.py +++ /dev/null @@ -1,428 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -End-to-end tests for the vLLM gRPC server. -""" - -import asyncio -import socket -import subprocess -import sys -import time - -import grpc -import pytest -import pytest_asyncio - -from vllm.grpc import vllm_engine_pb2, vllm_engine_pb2_grpc - -# Use a small model for fast testing -MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM" - - -def find_free_port() -> int: - """Find a free port on localhost.""" - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(("", 0)) - s.listen(1) - port = s.getsockname()[1] - return port - - -async def wait_for_server(port: int, timeout: float = 60.0) -> bool: - """Wait for the gRPC server to be ready by trying health checks.""" - start_time = time.time() - print("waiting for server to start...") - while time.time() - start_time < timeout: - try: - channel = grpc.aio.insecure_channel(f"localhost:{port}") - stub = vllm_engine_pb2_grpc.VllmEngineStub(channel) - request = vllm_engine_pb2.HealthCheckRequest() - response = await stub.HealthCheck(request, timeout=5.0) - await channel.close() - if response.healthy: - print("server returned healthy=True") - return True - except Exception: - await asyncio.sleep(0.5) - return False - - -class GrpcServerProcess: - """Manages a gRPC server running in a subprocess.""" - - def __init__(self): - self.process: subprocess.Popen | None = None - self.port: int | None = None - - async def start(self): - """Start the gRPC server process.""" - self.port = find_free_port() - - # Start the server as a subprocess - self.process = subprocess.Popen( - [ - sys.executable, - "-m", - "vllm.entrypoints.grpc_server", - "--model", - MODEL_NAME, - "--host", - "localhost", - "--port", - str(self.port), - "--max-num-batched-tokens", - "512", - "--disable-log-stats-server", - ], - ) - - # Wait for server to be ready - if not await wait_for_server(self.port): - self.stop() - raise RuntimeError("gRPC server failed to start within timeout") - - def stop(self): - """Stop the gRPC server process.""" - if self.process: - self.process.terminate() - try: - self.process.wait(timeout=10) - except subprocess.TimeoutExpired: - self.process.kill() - self.process.wait() - - -@pytest_asyncio.fixture(scope="module") -async def grpc_server(): - """Fixture providing a running gRPC server in a subprocess.""" - server = GrpcServerProcess() - await server.start() - - yield server - - server.stop() - - -@pytest_asyncio.fixture -async def grpc_client(grpc_server): - """Fixture providing a gRPC client connected to the server.""" - channel = grpc.aio.insecure_channel(f"localhost:{grpc_server.port}") - stub = vllm_engine_pb2_grpc.VllmEngineStub(channel) - - yield stub - - await channel.close() - - -@pytest.mark.asyncio -async def test_health_check(grpc_client): - """Test the HealthCheck RPC.""" - request = vllm_engine_pb2.HealthCheckRequest() - response = await grpc_client.HealthCheck(request) - - assert response.healthy is True - assert response.message == "Health" - - -@pytest.mark.asyncio -async def test_get_model_info(grpc_client): - """Test the GetModelInfo RPC.""" - request = vllm_engine_pb2.GetModelInfoRequest() - response = await grpc_client.GetModelInfo(request) - - assert response.model_path == MODEL_NAME - assert response.is_generation is True - assert response.max_context_length > 0 - assert response.vocab_size > 0 - assert response.supports_vision is False - - -@pytest.mark.asyncio -async def test_get_server_info(grpc_client): - """Test the GetServerInfo RPC.""" - request = vllm_engine_pb2.GetServerInfoRequest() - response = await grpc_client.GetServerInfo(request) - - assert response.active_requests >= 0 - assert response.is_paused is False - assert response.uptime_seconds >= 0 - assert response.server_type == "vllm-grpc" - assert response.last_receive_timestamp > 0 - - -@pytest.mark.asyncio -async def test_generate_non_streaming(grpc_client): - """Test the Generate RPC in non-streaming mode.""" - # Create a simple request - request = vllm_engine_pb2.GenerateRequest( - request_id="test-non-streaming-1", - tokenized=vllm_engine_pb2.TokenizedInput( - original_text="Hello, my name is", - input_ids=[15496, 11, 616, 1438, 318], # GPT-2 tokens for the prompt - ), - sampling_params=vllm_engine_pb2.SamplingParams( - temperature=0.0, - max_tokens=10, - n=1, - ), - stream=False, - ) - - # Collect all responses - responses = [] - async for response in grpc_client.Generate(request): - responses.append(response) - - # Should have exactly one response (complete) - assert len(responses) == 1 - - # Check the response - final_response = responses[0] - assert final_response.HasField("complete") - - complete = final_response.complete - assert len(complete.output_ids) > 0 - assert complete.finish_reason in ["stop", "length"] - assert complete.prompt_tokens > 0 - assert complete.completion_tokens > 0 - - -@pytest.mark.asyncio -async def test_generate_streaming(grpc_client): - """Test the Generate RPC in streaming mode.""" - request = vllm_engine_pb2.GenerateRequest( - request_id="test-streaming-1", - tokenized=vllm_engine_pb2.TokenizedInput( - original_text="The capital of France is", - input_ids=[464, 3139, 286, 4881, 318], # GPT-2 tokens - ), - sampling_params=vllm_engine_pb2.SamplingParams( - temperature=0.0, max_tokens=10, n=1 - ), - stream=True, - ) - - # Collect all responses - chunks = [] - complete_response = None - - async for response in grpc_client.Generate(request): - if response.HasField("chunk"): - chunks.append(response.chunk) - elif response.HasField("complete"): - complete_response = response.complete - - # Should have received some chunks - assert len(chunks) >= 0 # May have 0 chunks if generation is very fast - - # Should have a final complete response - assert complete_response is not None - assert complete_response.finish_reason in ["stop", "length"] - assert complete_response.prompt_tokens > 0 - - # Verify chunk structure - for chunk in chunks: - assert chunk.prompt_tokens > 0 - assert chunk.completion_tokens >= 0 - - -@pytest.mark.asyncio -async def test_generate_with_different_sampling_params(grpc_client): - """Test Generate with various sampling parameters.""" - # Test with temperature - request = vllm_engine_pb2.GenerateRequest( - request_id="test-sampling-temp", - tokenized=vllm_engine_pb2.TokenizedInput( - original_text="Hello", - input_ids=[15496], - ), - sampling_params=vllm_engine_pb2.SamplingParams( - temperature=0.8, top_p=0.95, max_tokens=5 - ), - stream=False, - ) - - responses = [r async for r in grpc_client.Generate(request)] - assert len(responses) == 1 - assert responses[0].HasField("complete") - - # Test with top_k - request = vllm_engine_pb2.GenerateRequest( - request_id="test-sampling-topk", - tokenized=vllm_engine_pb2.TokenizedInput( - original_text="Hello", - input_ids=[15496], - ), - sampling_params=vllm_engine_pb2.SamplingParams( - temperature=1.0, top_k=50, max_tokens=5 - ), - stream=False, - ) - - responses = [r async for r in grpc_client.Generate(request)] - assert len(responses) == 1 - assert responses[0].HasField("complete") - - -@pytest.mark.asyncio -async def test_generate_with_stop_strings(grpc_client): - """Test Generate with stop strings.""" - request = vllm_engine_pb2.GenerateRequest( - request_id="test-stop-strings", - tokenized=vllm_engine_pb2.TokenizedInput( - original_text="Hello", - input_ids=[15496], - ), - sampling_params=vllm_engine_pb2.SamplingParams( - temperature=0.0, - max_tokens=20, - stop=["\n", "END"], - ), - stream=False, - ) - - responses = [r async for r in grpc_client.Generate(request)] - assert len(responses) == 1 - assert responses[0].HasField("complete") - - complete = responses[0].complete - assert complete.finish_reason in ["stop", "length"] - - -@pytest.mark.asyncio -async def test_generate_multiple_requests(grpc_client): - """Test handling multiple concurrent Generate requests.""" - - async def make_request(request_id: str): - request = vllm_engine_pb2.GenerateRequest( - request_id=request_id, - tokenized=vllm_engine_pb2.TokenizedInput( - original_text="Hello", - input_ids=[15496], - ), - sampling_params=vllm_engine_pb2.SamplingParams( - temperature=0.0, max_tokens=5 - ), - stream=False, - ) - - responses = [r async for r in grpc_client.Generate(request)] - return responses[0] - - # Send multiple requests concurrently - tasks = [make_request(f"test-concurrent-{i}") for i in range(3)] - responses = await asyncio.gather(*tasks) - - # Verify all requests completed successfully - assert len(responses) == 3 - for i, response in enumerate(responses): - assert response.HasField("complete") - - -@pytest.mark.asyncio -async def test_generate_with_seed(grpc_client): - """Test Generate with a fixed seed for reproducibility.""" - - def make_request(request_id: str, seed: int): - return vllm_engine_pb2.GenerateRequest( - request_id=request_id, - tokenized=vllm_engine_pb2.TokenizedInput( - original_text="The future of AI is", - input_ids=[464, 2003, 286, 9552, 318], - ), - sampling_params=vllm_engine_pb2.SamplingParams( - temperature=1.0, max_tokens=10, seed=seed - ), - stream=False, - ) - - # Make two requests with the same seed - request1 = make_request("test-seed-1", 42) - request2 = make_request("test-seed-2", 42) - - response_list1 = [r async for r in grpc_client.Generate(request1)] - response_list2 = [r async for r in grpc_client.Generate(request2)] - - # Both should complete successfully - assert len(response_list1) == 1 - assert len(response_list2) == 1 - assert response_list1[0].HasField("complete") - assert response_list2[0].HasField("complete") - - # With the same seed, outputs should be identical - output_ids1 = list(response_list1[0].complete.output_ids) - output_ids2 = list(response_list2[0].complete.output_ids) - assert output_ids1 == output_ids2 - - -@pytest.mark.asyncio -async def test_generate_error_handling(grpc_client): - """Test error handling in Generate RPC.""" - # Request with invalid top_p value (-33) - request = vllm_engine_pb2.GenerateRequest( - request_id="test-error-invalid-topp", - sampling_params=vllm_engine_pb2.SamplingParams( - temperature=0.0, max_tokens=10, top_p=-33 - ), - stream=False, - ) - - # Should raise an error response - with pytest.raises(grpc.RpcError) as exc_info: - _ = [r async for r in grpc_client.Generate(request)] - - assert exc_info.value.code() == grpc.StatusCode.INVALID_ARGUMENT - assert "top_p must be in (0, 1], got -33.0" in exc_info.value.details() - - -@pytest.mark.asyncio -async def test_abort_request(grpc_client): - """Test the out-of-band Abort RPC.""" - request_id = "test-abort-1" - - # Start a long-running streaming generate request - generate_request = vllm_engine_pb2.GenerateRequest( - request_id=request_id, - tokenized=vllm_engine_pb2.TokenizedInput( - original_text="Hello", - input_ids=[15496], - ), - sampling_params=vllm_engine_pb2.SamplingParams( - temperature=0.0, - min_tokens=500, - max_tokens=500, # Request many tokens to ensure it runs long enough - ), - stream=True, - ) - - # Track whether we were aborted - was_aborted = False - received_chunks = 0 - - async def run_generate(): - nonlocal was_aborted, received_chunks - async for response in grpc_client.Generate(generate_request): - if response.HasField("chunk"): - received_chunks += 1 - - if response.HasField("complete"): - complete = response.complete - was_aborted = complete.finish_reason == "abort" - else: - was_aborted = False - - async def abort_after_delay(): - # Small delay to ensure generate has started - await asyncio.sleep(0.1) - abort_request = vllm_engine_pb2.AbortRequest(request_ids=[request_id]) - await grpc_client.Abort(abort_request) - - # Run generate and abort concurrently - await asyncio.gather(run_generate(), abort_after_delay()) - - # The request should have been aborted (received final chunk with - # "abort" finish reason) and finished early due to the abort. - assert was_aborted and received_chunks < 500, ( - "Request should have been aborted before generating all 500 tokens" - ) diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index 677c6ea0f..dab3a26db 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -51,6 +51,12 @@ class ServeSubcommand(CLISubcommand): if hasattr(args, "model_tag") and args.model_tag is not None: args.model = args.model_tag + if getattr(args, "grpc", False): + from vllm.entrypoints.grpc_server import serve_grpc + + uvloop.run(serve_grpc(args)) + return + if args.headless: if args.api_server_count is not None and args.api_server_count > 0: raise ValueError( @@ -127,6 +133,13 @@ class ServeSubcommand(CLISubcommand): ) serve_parser = make_arg_parser(serve_parser) + serve_parser.add_argument( + "--grpc", + action="store_true", + default=False, + help="Launch a gRPC server instead of the HTTP OpenAI-compatible " + "server. Requires: pip install vllm[grpc].", + ) serve_parser.epilog = VLLM_SUBCMD_PARSER_EPILOG.format(subcmd=self.name) return serve_parser diff --git a/vllm/entrypoints/grpc_server.py b/vllm/entrypoints/grpc_server.py old mode 100755 new mode 100644 index ec8f4804b..5bb8ea1b4 --- a/vllm/entrypoints/grpc_server.py +++ b/vllm/entrypoints/grpc_server.py @@ -5,7 +5,8 @@ """ vLLM gRPC Server -Starts a gRPC server for vLLM using the VllmEngine protocol. +Starts a gRPC server backed by AsyncLLM, using the VllmEngineServicer +from the smg-grpc-servicer package. Usage: python -m vllm.entrypoints.grpc_server --model @@ -22,19 +23,23 @@ import asyncio import signal import sys import time -from collections.abc import AsyncGenerator -import grpc +try: + import grpc + from grpc_reflection.v1alpha import reflection + from smg_grpc_proto import vllm_engine_pb2, vllm_engine_pb2_grpc + from smg_grpc_servicer.vllm.servicer import VllmEngineServicer +except ImportError: + raise ImportError( + "smg-grpc-servicer is required for gRPC mode. " + "Install it with: pip install vllm[grpc]" + ) from None + import uvloop -from grpc_reflection.v1alpha import reflection -from vllm import SamplingParams, TextPrompt, TokensPrompt from vllm.engine.arg_utils import AsyncEngineArgs from vllm.entrypoints.utils import log_version_and_model -from vllm.grpc import vllm_engine_pb2, vllm_engine_pb2_grpc from vllm.logger import init_logger -from vllm.outputs import RequestOutput -from vllm.sampling_params import RequestOutputKind, StructuredOutputsParams from vllm.usage.usage_lib import UsageContext from vllm.utils.argparse_utils import FlexibleArgumentParser from vllm.v1.engine.async_llm import AsyncLLM @@ -43,377 +48,9 @@ from vllm.version import __version__ as VLLM_VERSION logger = init_logger(__name__) -class VllmEngineServicer(vllm_engine_pb2_grpc.VllmEngineServicer): - """ - gRPC servicer implementing the VllmEngine service. - - Handles 6 RPCs: - - Generate: Streaming text generation - - Embed: Embeddings (TODO) - - HealthCheck: Health probe - - Abort: Cancel requests out-of-band - - GetModelInfo: Model metadata - - GetServerInfo: Server state - """ - - def __init__(self, async_llm: AsyncLLM, start_time: float): - """ - Initialize the servicer. - - Args: - async_llm: The AsyncLLM instance - start_time: The server start time, in seconds since epoch - """ - self.async_llm = async_llm - self.start_time = start_time - logger.info("VllmEngineServicer initialized") - - async def Generate( - self, - request: vllm_engine_pb2.GenerateRequest, - context: grpc.aio.ServicerContext, - ) -> AsyncGenerator[vllm_engine_pb2.GenerateResponse, None]: - """ - Handle streaming generation requests. - - Args: - request: The GenerateRequest protobuf - context: gRPC context - - Yields: - GenerateResponse protobuf messages (streaming) - """ - request_id = request.request_id - logger.debug("Generate request %s received.", request_id) - - try: - # Extract tokenized input - if request.WhichOneof("input") == "tokenized": - prompt: TokensPrompt = { - "prompt_token_ids": list(request.tokenized.input_ids) - } - if request.tokenized.original_text: - prompt["prompt"] = request.tokenized.original_text - else: - prompt: TextPrompt = {"prompt": request.text} - - # Build sampling params with detokenize=False - sampling_params = self._sampling_params_from_proto( - request.sampling_params, stream=request.stream - ) - tokenization_kwargs = self._tokenization_kwargs_from_proto( - request.sampling_params - ) - - async for output in self.async_llm.generate( - prompt=prompt, - sampling_params=sampling_params, - request_id=request_id, - tokenization_kwargs=tokenization_kwargs, - ): - # Convert vLLM output to protobuf - # For streaming, always send chunks - if request.stream: - yield self._chunk_response(output) - - # Send complete response when finished - if output.finished: - yield self._complete_response(output) - - except ValueError as e: - # Invalid request error (equiv to 400). - await context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e)) - except Exception as e: - logger.exception("Error in Generate for request %s", request_id) - await context.abort(grpc.StatusCode.INTERNAL, str(e)) - - async def Embed( - self, - request: vllm_engine_pb2.EmbedRequest, - context: grpc.aio.ServicerContext, - ) -> vllm_engine_pb2.EmbedResponse: - """ - Handle embedding requests. - - TODO: Implement in Phase 4 - - Args: - request: The EmbedRequest protobuf - context: gRPC context - - Returns: - EmbedResponse protobuf - """ - logger.warning("Embed RPC not yet implemented") - await context.abort( - grpc.StatusCode.UNIMPLEMENTED, "Embed RPC not yet implemented" - ) - - async def HealthCheck( - self, - request: vllm_engine_pb2.HealthCheckRequest, - context: grpc.aio.ServicerContext, - ) -> vllm_engine_pb2.HealthCheckResponse: - """ - Handle health check requests. - - Args: - request: The HealthCheckRequest protobuf - context: gRPC context - - Returns: - HealthCheckResponse protobuf - """ - is_healthy = not self.async_llm.errored - message = "Health" if is_healthy else "Engine is not alive" - - logger.debug("HealthCheck request: healthy=%s, message=%s", is_healthy, message) - - return vllm_engine_pb2.HealthCheckResponse(healthy=is_healthy, message=message) - - async def Abort( - self, - request: vllm_engine_pb2.AbortRequest, - context: grpc.aio.ServicerContext, - ) -> vllm_engine_pb2.AbortResponse: - """ - Out-of-band abort requests. - - Args: - request: The AbortRequest protobuf - context: gRPC context - - Returns: - AbortResponse protobuf - """ - request_ids = request.request_ids - logger.debug("Abort requests: %s", request_ids) - - await self.async_llm.abort(request_ids) - return vllm_engine_pb2.AbortResponse() - - async def GetModelInfo( - self, - request: vllm_engine_pb2.GetModelInfoRequest, - context: grpc.aio.ServicerContext, - ) -> vllm_engine_pb2.GetModelInfoResponse: - """ - Handle model info requests. - - Args: - request: The GetModelInfoRequest protobuf - context: gRPC context - - Returns: - GetModelInfoResponse protobuf - """ - model_config = self.async_llm.model_config - - return vllm_engine_pb2.GetModelInfoResponse( - model_path=model_config.model, - is_generation=model_config.runner_type == "generate", - max_context_length=model_config.max_model_len, - vocab_size=model_config.get_vocab_size(), - supports_vision=model_config.is_multimodal_model, - ) - - async def GetServerInfo( - self, - request: vllm_engine_pb2.GetServerInfoRequest, - context: grpc.aio.ServicerContext, - ) -> vllm_engine_pb2.GetServerInfoResponse: - """ - Handle server info requests. - - Args: - request: The GetServerInfoRequest protobuf - context: gRPC context - - Returns: - GetServerInfoResponse protobuf - """ - num_requests = self.async_llm.output_processor.get_num_unfinished_requests() - - return vllm_engine_pb2.GetServerInfoResponse( - active_requests=num_requests, - is_paused=False, # TODO - last_receive_timestamp=time.time(), # TODO looks wrong? - uptime_seconds=time.time() - self.start_time, - server_type="vllm-grpc", - ) - - # ========== Helper methods ========== - - @staticmethod - def _sampling_params_from_proto( - params: vllm_engine_pb2.SamplingParams, stream: bool = True - ) -> SamplingParams: - """ - Convert protobuf SamplingParams to vLLM SamplingParams. - - Args: - params: Protobuf SamplingParams message - stream: Whether streaming is enabled - - Returns: - vLLM SamplingParams with detokenize=False and structured_outputs - """ - # Build stop sequences - stop = list(params.stop) if params.stop else None - stop_token_ids = list(params.stop_token_ids) if params.stop_token_ids else None - - # Handle structured outputs constraints - structured_outputs = None - constraint_field = params.WhichOneof("constraint") - if constraint_field: - if constraint_field == "json_schema": - structured_outputs = StructuredOutputsParams(json=params.json_schema) - elif constraint_field == "regex": - structured_outputs = StructuredOutputsParams(regex=params.regex) - elif constraint_field == "grammar": - structured_outputs = StructuredOutputsParams(grammar=params.grammar) - elif constraint_field == "structural_tag": - structured_outputs = StructuredOutputsParams( - structural_tag=params.structural_tag - ) - elif constraint_field == "json_object": - structured_outputs = StructuredOutputsParams( - json_object=params.json_object - ) - elif constraint_field == "choice": - structured_outputs = StructuredOutputsParams( - choice=list(params.choice.choices) - ) - - # Create SamplingParams - # output_kind=DELTA: Return only new tokens in each chunk (for streaming) - return SamplingParams( - temperature=params.temperature if params.HasField("temperature") else 1.0, - top_p=params.top_p if params.top_p != 0.0 else 1.0, - top_k=params.top_k, - min_p=params.min_p, - frequency_penalty=params.frequency_penalty, - presence_penalty=params.presence_penalty, - repetition_penalty=params.repetition_penalty - if params.repetition_penalty != 0.0 - else 1.0, - max_tokens=params.max_tokens if params.HasField("max_tokens") else None, - min_tokens=params.min_tokens, - stop=stop, - stop_token_ids=stop_token_ids, - skip_special_tokens=params.skip_special_tokens, - spaces_between_special_tokens=params.spaces_between_special_tokens, - ignore_eos=params.ignore_eos, - n=params.n if params.n > 0 else 1, - logprobs=params.logprobs if params.HasField("logprobs") else None, - prompt_logprobs=params.prompt_logprobs - if params.HasField("prompt_logprobs") - else None, - seed=params.seed if params.HasField("seed") else None, - include_stop_str_in_output=params.include_stop_str_in_output, - logit_bias=dict(params.logit_bias) if params.logit_bias else None, - structured_outputs=structured_outputs, - # detokenize must be True if stop strings are used - detokenize=bool(stop), - output_kind=RequestOutputKind.DELTA - if stream - else RequestOutputKind.FINAL_ONLY, - ) - - @staticmethod - def _tokenization_kwargs_from_proto( - params: vllm_engine_pb2.SamplingParams, - ) -> dict[str, int] | None: - if params.HasField("truncate_prompt_tokens"): - return {"truncate_prompt_tokens": params.truncate_prompt_tokens} - return None - - @staticmethod - def _chunk_response(output: RequestOutput) -> vllm_engine_pb2.GenerateResponse: - """ - Build a streaming chunk response from vLLM output. - When output_kind=DELTA, vLLM returns only new tokens automatically. - - Args: - output: vLLM RequestOutput (with delta tokens when output_kind=DELTA) - - Returns: - GenerateResponse with chunk field set - """ - # Get the completion output (first one if n > 1) - completion = output.outputs[0] if output.outputs else None - - if completion is None: - # Empty chunk - return vllm_engine_pb2.GenerateResponse( - chunk=vllm_engine_pb2.GenerateStreamChunk( - token_ids=[], - prompt_tokens=0, - completion_tokens=0, - cached_tokens=0, - ), - ) - - # When output_kind=DELTA, completion.token_ids contains only new tokens - # vLLM handles the delta logic internally - # completion_tokens = delta count (client will accumulate) - return vllm_engine_pb2.GenerateResponse( - chunk=vllm_engine_pb2.GenerateStreamChunk( - token_ids=completion.token_ids, - prompt_tokens=len(output.prompt_token_ids) - if output.prompt_token_ids - else 0, - completion_tokens=len(completion.token_ids), # Delta count - cached_tokens=output.num_cached_tokens, - ), - ) - - @staticmethod - def _complete_response(output: RequestOutput) -> vllm_engine_pb2.GenerateResponse: - """ - Build a final completion response from vLLM output. - - Args: - output: vLLM RequestOutput (finished=True) - - Returns: - GenerateResponse with complete field set - """ - # Get the completion output (first one if n > 1) - completion = output.outputs[0] if output.outputs else None - - if completion is None: - # Empty completion - return vllm_engine_pb2.GenerateResponse( - complete=vllm_engine_pb2.GenerateComplete( - output_ids=[], - finish_reason="error", - prompt_tokens=0, - completion_tokens=0, - cached_tokens=0, - ), - ) - - # Build complete response - # When streaming (DELTA mode): completion.token_ids will be empty/last delta - # When non-streaming (FINAL_ONLY mode): completion.token_ids has all tokens - # Client will accumulate token counts for streaming - return vllm_engine_pb2.GenerateResponse( - complete=vllm_engine_pb2.GenerateComplete( - output_ids=completion.token_ids, - finish_reason=completion.finish_reason or "stop", - prompt_tokens=len(output.prompt_token_ids) - if output.prompt_token_ids - else 0, - completion_tokens=len(completion.token_ids), - cached_tokens=output.num_cached_tokens, - ), - ) - - async def serve_grpc(args: argparse.Namespace): """ - Main serving function. + Main gRPC serving function. Args: args: Parsed command line arguments @@ -428,7 +65,7 @@ async def serve_grpc(args: argparse.Namespace): # Build vLLM config vllm_config = engine_args.create_engine_config( - usage_context=UsageContext.OPENAI_API_SERVER + usage_context=UsageContext.OPENAI_API_SERVER, ) # Create AsyncLLM @@ -436,7 +73,7 @@ async def serve_grpc(args: argparse.Namespace): vllm_config=vllm_config, usage_context=UsageContext.OPENAI_API_SERVER, enable_log_requests=args.enable_log_requests, - disable_log_stats=args.disable_log_stats_server, + disable_log_stats=args.disable_log_stats, ) # Create servicer @@ -447,6 +84,11 @@ async def serve_grpc(args: argparse.Namespace): options=[ ("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1), + # Tolerate client keepalive pings every 10s (default 300s is too + # strict for non-streaming requests where no DATA frames flow + # during generation) + ("grpc.http2.min_recv_ping_interval_without_data_ms", 10000), + ("grpc.keepalive_permit_without_calls", True), ], ) @@ -461,46 +103,42 @@ async def serve_grpc(args: argparse.Namespace): reflection.enable_server_reflection(service_names, server) # Bind to address - address = f"{args.host}:{args.port}" + host = args.host or "0.0.0.0" + address = f"{host}:{args.port}" server.add_insecure_port(address) - # Start server - await server.start() - logger.info("vLLM gRPC server started on %s", address) - logger.info("Server is ready to accept requests") - - # Handle shutdown signals - loop = asyncio.get_running_loop() - stop_event = asyncio.Event() - - def signal_handler(): - logger.info("Received shutdown signal") - stop_event.set() - - for sig in (signal.SIGTERM, signal.SIGINT): - loop.add_signal_handler(sig, signal_handler) - - # Serve until shutdown signal try: - await stop_event.wait() - except KeyboardInterrupt: - logger.info("Interrupted by user") + # Start server + await server.start() + logger.info("vLLM gRPC server started on %s", address) + logger.info("Server is ready to accept requests") + + # Handle shutdown signals + loop = asyncio.get_running_loop() + stop_event = asyncio.Event() + + def signal_handler(): + logger.info("Received shutdown signal") + stop_event.set() + + for sig in (signal.SIGTERM, signal.SIGINT): + loop.add_signal_handler(sig, signal_handler) + + try: + await stop_event.wait() + except KeyboardInterrupt: + logger.info("Interrupted by user") finally: logger.info("Shutting down vLLM gRPC server...") - - # Stop gRPC server await server.stop(grace=5.0) logger.info("gRPC server stopped") - - # Shutdown AsyncLLM async_llm.shutdown() logger.info("AsyncLLM engine stopped") - logger.info("Shutdown complete") def main(): - """Main entry point.""" + """Main entry point for python -m vllm.entrypoints.grpc_server.""" parser = FlexibleArgumentParser( description="vLLM gRPC Server", ) @@ -518,13 +156,6 @@ def main(): default=50051, help="Port to bind gRPC server to", ) - parser.add_argument( - "--disable-log-stats-server", - action="store_true", - help="Disable stats logging on server side", - ) - - # Add vLLM engine args parser = AsyncEngineArgs.add_cli_args(parser) args = parser.parse_args() diff --git a/vllm/grpc/__init__.py b/vllm/grpc/__init__.py deleted file mode 100644 index b59ee96fb..000000000 --- a/vllm/grpc/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -vLLM gRPC protocol definitions. - -This module contains the protocol buffer definitions for vLLM's gRPC API. -The protobuf files are compiled into Python code using grpcio-tools. -""" - -# These imports will be available after protobuf compilation -# from vllm.grpc import vllm_engine_pb2 -# from vllm.grpc import vllm_engine_pb2_grpc - -__all__ = [ - "vllm_engine_pb2", - "vllm_engine_pb2_grpc", -] diff --git a/vllm/grpc/compile_protos.py b/vllm/grpc/compile_protos.py deleted file mode 100755 index 92ad46e16..000000000 --- a/vllm/grpc/compile_protos.py +++ /dev/null @@ -1,94 +0,0 @@ -#!/usr/bin/env python3 -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -Compile vLLM protobuf definitions into Python code. - -This script uses grpcio-tools to generate *_pb2.py, *_pb2_grpc.py, and -*_pb2.pyi (type stubs) files from the vllm_engine.proto definition. - -NOTE: Proto compilation happens automatically during package build (via setup.py). -This script is provided for developers who want to regenerate protos manually, -e.g., after modifying vllm_engine.proto. - -Usage: - python vllm/grpc/compile_protos.py - -Requirements: - pip install grpcio-tools -""" - -import sys -from pathlib import Path - - -def compile_protos(): - """Compile protobuf definitions.""" - # Get the vllm package root directory - script_dir = Path(__file__).parent - vllm_package_root = script_dir.parent.parent # vllm/vllm/grpc -> vllm/ - - proto_file = script_dir / "vllm_engine.proto" - - if not proto_file.exists(): - print(f"Error: Proto file not found at {proto_file}") - return 1 - - print(f"Compiling protobuf: {proto_file}") - print(f"Output directory: {script_dir}") - - # Compile the proto file - # We use vllm/vllm as the proto_path so that the package is vllm.grpc.engine - try: - from grpc_tools import protoc - - result = protoc.main( - [ - "grpc_tools.protoc", - f"--proto_path={vllm_package_root}", - f"--python_out={vllm_package_root}", - f"--grpc_python_out={vllm_package_root}", - f"--pyi_out={vllm_package_root}", # Generate type stubs - str(script_dir / "vllm_engine.proto"), - ] - ) - - if result == 0: - # Add SPDX headers to generated files - spdx_header = ( - "# SPDX-License-Identifier: Apache-2.0\n" - "# SPDX-FileCopyrightText: Copyright contributors to the vLLM project\n" - ) - - for generated_file in [ - script_dir / "vllm_engine_pb2.py", - script_dir / "vllm_engine_pb2_grpc.py", - script_dir / "vllm_engine_pb2.pyi", - ]: - if generated_file.exists(): - content = generated_file.read_text() - if not content.startswith("# SPDX-License-Identifier"): - # Add mypy ignore-errors comment for all generated files - header = spdx_header + "# mypy: ignore-errors\n" - generated_file.write_text(header + content) - - print("✓ Protobuf compilation successful!") - print(f" Generated: {script_dir / 'vllm_engine_pb2.py'}") - print(f" Generated: {script_dir / 'vllm_engine_pb2_grpc.py'}") - print(f" Generated: {script_dir / 'vllm_engine_pb2.pyi'} (type stubs)") - return 0 - else: - print(f"Error: protoc returned {result}") - return result - - except ImportError: - print("Error: grpcio-tools not installed") - print("Install with: pip install grpcio-tools") - return 1 - except Exception as e: - print(f"Error during compilation: {e}") - return 1 - - -if __name__ == "__main__": - sys.exit(compile_protos()) diff --git a/vllm/grpc/vllm_engine.proto b/vllm/grpc/vllm_engine.proto deleted file mode 100644 index bbb1b9b00..000000000 --- a/vllm/grpc/vllm_engine.proto +++ /dev/null @@ -1,195 +0,0 @@ -syntax = "proto3"; - -package vllm.grpc.engine; - -// Service definition for vLLM engine communication -// This protocol is designed for efficient binary communication between -// the Rust router and vLLM Python engine (AsyncLLM). -service VllmEngine { - // Submit a generation request (supports streaming) - rpc Generate(GenerateRequest) returns (stream GenerateResponse); - - // Submit an embedding request - rpc Embed(EmbedRequest) returns (EmbedResponse); - - // Health check - rpc HealthCheck(HealthCheckRequest) returns (HealthCheckResponse); - - // Abort a running request - rpc Abort(AbortRequest) returns (AbortResponse); - - // Get model information - rpc GetModelInfo(GetModelInfoRequest) returns (GetModelInfoResponse); - - // Get server information - rpc GetServerInfo(GetServerInfoRequest) returns (GetServerInfoResponse); -} - -// ===================== -// Common Types -// ===================== - -// Sampling parameters for text generation -message SamplingParams { - optional float temperature = 1; - float top_p = 2; - uint32 top_k = 3; - float min_p = 4; - float frequency_penalty = 5; - float presence_penalty = 6; - float repetition_penalty = 7; - - optional uint32 max_tokens = 8; - uint32 min_tokens = 9; - - repeated string stop = 10; - repeated uint32 stop_token_ids = 11; - - bool skip_special_tokens = 12; - bool spaces_between_special_tokens = 13; - bool ignore_eos = 14; - - uint32 n = 15; // Number of parallel samples - - // Logprobs configuration - optional int32 logprobs = 22; // Number of log probabilities per output token (-1 for all) - optional int32 prompt_logprobs = 23; // Number of log probabilities per prompt token (-1 for all) - - // Additional vLLM fields - optional int32 seed = 24; // Random seed for reproducibility - bool include_stop_str_in_output = 25; // Whether to include stop strings in output - map logit_bias = 26; // Token ID to bias mapping (-100 to 100) - optional int32 truncate_prompt_tokens = 27; // Prompt truncation (-1 for model max) - - // Structured outputs (one of) - matches vLLM's StructuredOutputsParams - oneof constraint { - string json_schema = 16; // JSON schema for structured output - string regex = 17; // Regex pattern - string grammar = 18; // Grammar/EBNF for structured output - string structural_tag = 19; // Structural tag (e.g., Harmony models) - bool json_object = 20; // Force JSON object output - ChoiceConstraint choice = 21; // List of allowed choices - } -} - -// Choice constraint for structured outputs -message ChoiceConstraint { - repeated string choices = 1; -} - -// Pre-tokenized input from Rust router -message TokenizedInput { - string original_text = 1; // For reference/debugging - repeated uint32 input_ids = 2; // Actual token IDs to process -} - -// ===================== -// Generate Request -// ===================== - -message GenerateRequest { - string request_id = 1; - - // Prompt input - oneof input { - TokenizedInput tokenized = 2; - string text = 3; - } - - // Generation parameters (includes logprobs config) - SamplingParams sampling_params = 4; - - // Streaming - bool stream = 5; -} - -// ===================== -// Generate Response -// ===================== - -message GenerateResponse { - oneof response { - GenerateStreamChunk chunk = 1; // For streaming - GenerateComplete complete = 2; // For final/non-streaming - } -} - -message GenerateStreamChunk { - repeated uint32 token_ids = 1; // Incremental tokens - uint32 prompt_tokens = 2; - uint32 completion_tokens = 3; - uint32 cached_tokens = 4; - - // Logprobs support (TODO: implement in Phase 4) - // OutputLogProbs output_logprobs = 5; - // InputLogProbs input_logprobs = 6; // Only in first chunk -} - -message GenerateComplete { - repeated uint32 output_ids = 1; // All output tokens - string finish_reason = 2; // "stop", "length", "abort" - uint32 prompt_tokens = 3; - uint32 completion_tokens = 4; - uint32 cached_tokens = 5; - - // Logprobs support (TODO: implement in Phase 4) - // OutputLogProbs output_logprobs = 6; - // InputLogProbs input_logprobs = 7; -} - -// ===================== -// Embedding Request -// ===================== - -message EmbedRequest { - string request_id = 1; - TokenizedInput tokenized = 2; -} - -message EmbedResponse { - repeated float embedding = 1; - uint32 prompt_tokens = 2; - uint32 embedding_dim = 3; -} - -// ===================== -// Management Operations -// ===================== - -message HealthCheckRequest {} - -message HealthCheckResponse { - bool healthy = 1; - string message = 2; -} - -message AbortRequest { - repeated string request_ids = 1; -} - -message AbortResponse { -} - -// ===================== -// Model and Server Info -// ===================== - -message GetModelInfoRequest {} - -message GetModelInfoResponse { - string model_path = 1; - bool is_generation = 2; - uint32 max_context_length = 3; - uint32 vocab_size = 4; - bool supports_vision = 5; -} - -message GetServerInfoRequest {} - -message GetServerInfoResponse { - uint32 active_requests = 1; - bool is_paused = 2; - double last_receive_timestamp = 3; - double uptime_seconds = 4; - string server_type = 5; // "vllm-grpc" -}