[grpc] Support gRPC server entrypoint (#30190)
Signed-off-by: Chang Su <chang.s.su@oracle.com>
Signed-off-by: njhill <nickhill123@gmail.com>
Signed-off-by: Nick Hill <nickhill123@gmail.com>
Co-authored-by: njhill <nickhill123@gmail.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>
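For context, a minimal sketch of how the new entrypoint is exercised, using only the module path, CLI flags, and RPC names that appear in the test below; the port number is an arbitrary example and the exact CLI surface may differ:

# In one terminal, start the server (flags as used by the test below):
#   python -m vllm.entrypoints.grpc_server --model hmellor/tiny-random-LlamaForCausalLM \
#       --host localhost --port 50051
import asyncio

import grpc

from vllm.grpc import vllm_engine_pb2, vllm_engine_pb2_grpc


async def main() -> None:
    # 50051 is only an example; use whatever was passed to --port.
    async with grpc.aio.insecure_channel("localhost:50051") as channel:
        stub = vllm_engine_pb2_grpc.VllmEngineStub(channel)
        # Probe readiness via the HealthCheck RPC, as wait_for_server() does in the test.
        response = await stub.HealthCheck(vllm_engine_pb2.HealthCheckRequest())
        print(response.healthy, response.message)


asyncio.run(main())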
tests/entrypoints/test_grpc_server.py (new file, 428 additions)
@@ -0,0 +1,428 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
End-to-end tests for the vLLM gRPC server.
"""

import asyncio
import socket
import subprocess
import sys
import time

import grpc
import pytest
import pytest_asyncio

from vllm.grpc import vllm_engine_pb2, vllm_engine_pb2_grpc

# Use a small model for fast testing
MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"


def find_free_port() -> int:
    """Find a free port on localhost."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        s.listen(1)
        port = s.getsockname()[1]
    return port


async def wait_for_server(port: int, timeout: float = 30.0) -> bool:
    """Wait for the gRPC server to be ready by trying health checks."""
    start_time = time.time()
    print("waiting for server to start...")
    while time.time() - start_time < timeout:
        try:
            channel = grpc.aio.insecure_channel(f"localhost:{port}")
            stub = vllm_engine_pb2_grpc.VllmEngineStub(channel)
            request = vllm_engine_pb2.HealthCheckRequest()
            response = await stub.HealthCheck(request, timeout=5.0)
            await channel.close()
            if response.healthy:
                print("server returned healthy=True")
                return True
        except Exception:
            await asyncio.sleep(0.5)
    return False


class GrpcServerProcess:
    """Manages a gRPC server running in a subprocess."""

    def __init__(self):
        self.process: subprocess.Popen | None = None
        self.port: int | None = None

    async def start(self):
        """Start the gRPC server process."""
        self.port = find_free_port()

        # Start the server as a subprocess
        self.process = subprocess.Popen(
            [
                sys.executable,
                "-m",
                "vllm.entrypoints.grpc_server",
                "--model",
                MODEL_NAME,
                "--host",
                "localhost",
                "--port",
                str(self.port),
                "--max-num-batched-tokens",
                "512",
                "--disable-log-stats-server",
            ],
        )

        # Wait for server to be ready
        if not await wait_for_server(self.port):
            self.stop()
            raise RuntimeError("gRPC server failed to start within timeout")

    def stop(self):
        """Stop the gRPC server process."""
        if self.process:
            self.process.terminate()
            try:
                self.process.wait(timeout=10)
            except subprocess.TimeoutExpired:
                self.process.kill()
                self.process.wait()


@pytest_asyncio.fixture(scope="module")
async def grpc_server():
    """Fixture providing a running gRPC server in a subprocess."""
    server = GrpcServerProcess()
    await server.start()

    yield server

    server.stop()


@pytest_asyncio.fixture
async def grpc_client(grpc_server):
    """Fixture providing a gRPC client connected to the server."""
    channel = grpc.aio.insecure_channel(f"localhost:{grpc_server.port}")
    stub = vllm_engine_pb2_grpc.VllmEngineStub(channel)

    yield stub

    await channel.close()


@pytest.mark.asyncio
async def test_health_check(grpc_client):
    """Test the HealthCheck RPC."""
    request = vllm_engine_pb2.HealthCheckRequest()
    response = await grpc_client.HealthCheck(request)

    assert response.healthy is True
    assert response.message == "Health"


@pytest.mark.asyncio
async def test_get_model_info(grpc_client):
    """Test the GetModelInfo RPC."""
    request = vllm_engine_pb2.GetModelInfoRequest()
    response = await grpc_client.GetModelInfo(request)

    assert response.model_path == MODEL_NAME
    assert response.is_generation is True
    assert response.max_context_length > 0
    assert response.vocab_size > 0
    assert response.supports_vision is False


@pytest.mark.asyncio
async def test_get_server_info(grpc_client):
    """Test the GetServerInfo RPC."""
    request = vllm_engine_pb2.GetServerInfoRequest()
    response = await grpc_client.GetServerInfo(request)

    assert response.active_requests >= 0
    assert response.is_paused is False
    assert response.uptime_seconds >= 0
    assert response.server_type == "vllm-grpc"
    assert response.last_receive_timestamp > 0


@pytest.mark.asyncio
async def test_generate_non_streaming(grpc_client):
    """Test the Generate RPC in non-streaming mode."""
    # Create a simple request
    request = vllm_engine_pb2.GenerateRequest(
        request_id="test-non-streaming-1",
        tokenized=vllm_engine_pb2.TokenizedInput(
            original_text="Hello, my name is",
            input_ids=[15496, 11, 616, 1438, 318],  # GPT-2 tokens for the prompt
        ),
        sampling_params=vllm_engine_pb2.SamplingParams(
            temperature=0.0,
            max_tokens=10,
            n=1,
        ),
        stream=False,
    )

    # Collect all responses
    responses = []
    async for response in grpc_client.Generate(request):
        responses.append(response)

    # Should have exactly one response (complete)
    assert len(responses) == 1

    # Check the response
    final_response = responses[0]
    assert final_response.HasField("complete")

    complete = final_response.complete
    assert len(complete.output_ids) > 0
    assert complete.finish_reason in ["stop", "length"]
    assert complete.prompt_tokens > 0
    assert complete.completion_tokens > 0


@pytest.mark.asyncio
async def test_generate_streaming(grpc_client):
    """Test the Generate RPC in streaming mode."""
    request = vllm_engine_pb2.GenerateRequest(
        request_id="test-streaming-1",
        tokenized=vllm_engine_pb2.TokenizedInput(
            original_text="The capital of France is",
            input_ids=[464, 3139, 286, 4881, 318],  # GPT-2 tokens
        ),
        sampling_params=vllm_engine_pb2.SamplingParams(
            temperature=0.0, max_tokens=10, n=1
        ),
        stream=True,
    )

    # Collect all responses
    chunks = []
    complete_response = None

    async for response in grpc_client.Generate(request):
        if response.HasField("chunk"):
            chunks.append(response.chunk)
        elif response.HasField("complete"):
            complete_response = response.complete

    # Should have received some chunks
    assert len(chunks) >= 0  # May have 0 chunks if generation is very fast

    # Should have a final complete response
    assert complete_response is not None
    assert complete_response.finish_reason in ["stop", "length"]
    assert complete_response.prompt_tokens > 0

    # Verify chunk structure
    for chunk in chunks:
        assert chunk.prompt_tokens > 0
        assert chunk.completion_tokens >= 0


@pytest.mark.asyncio
async def test_generate_with_different_sampling_params(grpc_client):
    """Test Generate with various sampling parameters."""
    # Test with temperature
    request = vllm_engine_pb2.GenerateRequest(
        request_id="test-sampling-temp",
        tokenized=vllm_engine_pb2.TokenizedInput(
            original_text="Hello",
            input_ids=[15496],
        ),
        sampling_params=vllm_engine_pb2.SamplingParams(
            temperature=0.8, top_p=0.95, max_tokens=5
        ),
        stream=False,
    )

    responses = [r async for r in grpc_client.Generate(request)]
    assert len(responses) == 1
    assert responses[0].HasField("complete")

    # Test with top_k
    request = vllm_engine_pb2.GenerateRequest(
        request_id="test-sampling-topk",
        tokenized=vllm_engine_pb2.TokenizedInput(
            original_text="Hello",
            input_ids=[15496],
        ),
        sampling_params=vllm_engine_pb2.SamplingParams(
            temperature=1.0, top_k=50, max_tokens=5
        ),
        stream=False,
    )

    responses = [r async for r in grpc_client.Generate(request)]
    assert len(responses) == 1
    assert responses[0].HasField("complete")


@pytest.mark.asyncio
async def test_generate_with_stop_strings(grpc_client):
    """Test Generate with stop strings."""
    request = vllm_engine_pb2.GenerateRequest(
        request_id="test-stop-strings",
        tokenized=vllm_engine_pb2.TokenizedInput(
            original_text="Hello",
            input_ids=[15496],
        ),
        sampling_params=vllm_engine_pb2.SamplingParams(
            temperature=0.0,
            max_tokens=20,
            stop=["\n", "END"],
        ),
        stream=False,
    )

    responses = [r async for r in grpc_client.Generate(request)]
    assert len(responses) == 1
    assert responses[0].HasField("complete")

    complete = responses[0].complete
    assert complete.finish_reason in ["stop", "length"]


@pytest.mark.asyncio
async def test_generate_multiple_requests(grpc_client):
    """Test handling multiple concurrent Generate requests."""

    async def make_request(request_id: str):
        request = vllm_engine_pb2.GenerateRequest(
            request_id=request_id,
            tokenized=vllm_engine_pb2.TokenizedInput(
                original_text="Hello",
                input_ids=[15496],
            ),
            sampling_params=vllm_engine_pb2.SamplingParams(
                temperature=0.0, max_tokens=5
            ),
            stream=False,
        )

        responses = [r async for r in grpc_client.Generate(request)]
        return responses[0]

    # Send multiple requests concurrently
    tasks = [make_request(f"test-concurrent-{i}") for i in range(3)]
    responses = await asyncio.gather(*tasks)

    # Verify all requests completed successfully
    assert len(responses) == 3
    for i, response in enumerate(responses):
        assert response.HasField("complete")


@pytest.mark.asyncio
async def test_generate_with_seed(grpc_client):
    """Test Generate with a fixed seed for reproducibility."""

    def make_request(request_id: str, seed: int):
        return vllm_engine_pb2.GenerateRequest(
            request_id=request_id,
            tokenized=vllm_engine_pb2.TokenizedInput(
                original_text="The future of AI is",
                input_ids=[464, 2003, 286, 9552, 318],
            ),
            sampling_params=vllm_engine_pb2.SamplingParams(
                temperature=1.0, max_tokens=10, seed=seed
            ),
            stream=False,
        )

    # Make two requests with the same seed
    request1 = make_request("test-seed-1", 42)
    request2 = make_request("test-seed-2", 42)

    response_list1 = [r async for r in grpc_client.Generate(request1)]
    response_list2 = [r async for r in grpc_client.Generate(request2)]

    # Both should complete successfully
    assert len(response_list1) == 1
    assert len(response_list2) == 1
    assert response_list1[0].HasField("complete")
    assert response_list2[0].HasField("complete")

    # With the same seed, outputs should be identical
    output_ids1 = list(response_list1[0].complete.output_ids)
    output_ids2 = list(response_list2[0].complete.output_ids)
    assert output_ids1 == output_ids2


@pytest.mark.asyncio
async def test_generate_error_handling(grpc_client):
    """Test error handling in Generate RPC."""
    # Request with invalid top_p value (-33)
    request = vllm_engine_pb2.GenerateRequest(
        request_id="test-error-invalid-topp",
        sampling_params=vllm_engine_pb2.SamplingParams(
            temperature=0.0, max_tokens=10, top_p=-33
        ),
        stream=False,
    )

    # Should raise an error response
    with pytest.raises(grpc.RpcError) as exc_info:
        _ = [r async for r in grpc_client.Generate(request)]

    assert exc_info.value.code() == grpc.StatusCode.INVALID_ARGUMENT
    assert "top_p must be in (0, 1], got -33.0" in exc_info.value.details()


@pytest.mark.asyncio
async def test_abort_request(grpc_client):
    """Test the out-of-band Abort RPC."""
    request_id = "test-abort-1"

    # Start a long-running streaming generate request
    generate_request = vllm_engine_pb2.GenerateRequest(
        request_id=request_id,
        tokenized=vllm_engine_pb2.TokenizedInput(
            original_text="Hello",
            input_ids=[15496],
        ),
        sampling_params=vllm_engine_pb2.SamplingParams(
            temperature=0.0,
            min_tokens=500,
            max_tokens=500,  # Request many tokens to ensure it runs long enough
        ),
        stream=True,
    )

    # Track whether we were aborted
    was_aborted = False
    received_chunks = 0

    async def run_generate():
        nonlocal was_aborted, received_chunks
        async for response in grpc_client.Generate(generate_request):
            if response.HasField("chunk"):
                received_chunks += 1

            if response.HasField("complete"):
                complete = response.complete
                was_aborted = complete.finish_reason == "abort"
            else:
                was_aborted = False

    async def abort_after_delay():
        # Small delay to ensure generate has started
        await asyncio.sleep(0.1)
        abort_request = vllm_engine_pb2.AbortRequest(request_ids=[request_id])
        await grpc_client.Abort(abort_request)

    # Run generate and abort concurrently
    await asyncio.gather(run_generate(), abort_after_delay())

    # The request should have been aborted (received final chunk with
    # "abort" finish reason) and finished early due to the abort.
    assert was_aborted and received_chunks < 500, (
        "Request should have been aborted before generating all 500 tokens"
    )
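To run these tests locally, something like the following should work (assuming pytest and pytest-asyncio are installed, as the imports above imply):

pytest -q tests/entrypoints/test_grpc_server.py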