# vllm/tests/entrypoints/test_grpc_server.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
End-to-end tests for the vLLM gRPC server.
"""

import asyncio
import socket
import subprocess
import sys
import time

import grpc
import pytest
import pytest_asyncio

from vllm.grpc import vllm_engine_pb2, vllm_engine_pb2_grpc

# Use a small model for fast testing
MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
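
# These tests are typically run with pytest, e.g.:
#
#     pytest tests/entrypoints/test_grpc_server.py -v
#
# (exact path assumed; adjust to where this file lives in your checkout)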


def find_free_port() -> int:
    """Find a free port on localhost."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        s.listen(1)
        port = s.getsockname()[1]
    return port
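
# Note: there is an inherent race here: the port is released when the probe
# socket closes, so another process could in principle grab it before the
# server binds. Acceptable for tests; retry the fixture if this ever flakes.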


async def wait_for_server(port: int, timeout: float = 60.0) -> bool:
    """Wait for the gRPC server to be ready by trying health checks."""
    start_time = time.time()
    print("waiting for server to start...")
    while time.time() - start_time < timeout:
        try:
            channel = grpc.aio.insecure_channel(f"localhost:{port}")
            stub = vllm_engine_pb2_grpc.VllmEngineStub(channel)
            request = vllm_engine_pb2.HealthCheckRequest()
            response = await stub.HealthCheck(request, timeout=5.0)
            await channel.close()
            if response.healthy:
                print("server returned healthy=True")
                return True
        except Exception:
            await asyncio.sleep(0.5)
    return False
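
# An alternative readiness probe would wait on channel state rather than
# polling an RPC. A sketch using the grpc.aio API:
#
#     channel = grpc.aio.insecure_channel(f"localhost:{port}")
#     await asyncio.wait_for(channel.channel_ready(), timeout=60.0)
#
# Polling HealthCheck is used above instead, since it confirms the engine is
# actually serving, not merely that the socket accepts connections.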


class GrpcServerProcess:
    """Manages a gRPC server running in a subprocess."""

    def __init__(self):
        self.process: subprocess.Popen | None = None
        self.port: int | None = None

    async def start(self):
        """Start the gRPC server process."""
        self.port = find_free_port()
        # Start the server as a subprocess
        self.process = subprocess.Popen(
            [
                sys.executable,
                "-m",
                "vllm.entrypoints.grpc_server",
                "--model",
                MODEL_NAME,
                "--host",
                "localhost",
                "--port",
                str(self.port),
                "--max-num-batched-tokens",
                "512",
                "--disable-log-stats-server",
            ],
        )
        # Wait for server to be ready
        if not await wait_for_server(self.port):
            self.stop()
            raise RuntimeError("gRPC server failed to start within timeout")

    def stop(self):
        """Stop the gRPC server process."""
        if self.process:
            self.process.terminate()
            try:
                self.process.wait(timeout=10)
            except subprocess.TimeoutExpired:
                self.process.kill()
                self.process.wait()
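
# Running the server in a separate OS process keeps the engine's state (and
# any GPU/CUDA initialization) isolated from the pytest process, at the cost
# of a slower startup per test module.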


@pytest_asyncio.fixture(scope="module")
async def grpc_server():
    """Fixture providing a running gRPC server in a subprocess."""
    server = GrpcServerProcess()
    await server.start()
    yield server
    server.stop()


@pytest_asyncio.fixture
async def grpc_client(grpc_server):
    """Fixture providing a gRPC client connected to the server."""
    channel = grpc.aio.insecure_channel(f"localhost:{grpc_server.port}")
    stub = vllm_engine_pb2_grpc.VllmEngineStub(channel)
    yield stub
    await channel.close()
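
# For reference, the two fixtures above compose to roughly the following
# standalone client; a sketch only (underscore-prefixed so pytest will not
# collect it, and it is unused by the tests below):


async def _example_standalone_health_check(port: int) -> bool:
    """Open a channel, issue one HealthCheck, and report the result."""
    channel = grpc.aio.insecure_channel(f"localhost:{port}")
    stub = vllm_engine_pb2_grpc.VllmEngineStub(channel)
    response = await stub.HealthCheck(vllm_engine_pb2.HealthCheckRequest())
    await channel.close()
    return response.healthy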


@pytest.mark.asyncio
async def test_health_check(grpc_client):
    """Test the HealthCheck RPC."""
    request = vllm_engine_pb2.HealthCheckRequest()
    response = await grpc_client.HealthCheck(request)
    assert response.healthy is True
    assert response.message == "Health"


@pytest.mark.asyncio
async def test_get_model_info(grpc_client):
    """Test the GetModelInfo RPC."""
    request = vllm_engine_pb2.GetModelInfoRequest()
    response = await grpc_client.GetModelInfo(request)
    assert response.model_path == MODEL_NAME
    assert response.is_generation is True
    assert response.max_context_length > 0
    assert response.vocab_size > 0
    assert response.supports_vision is False


@pytest.mark.asyncio
async def test_get_server_info(grpc_client):
    """Test the GetServerInfo RPC."""
    request = vllm_engine_pb2.GetServerInfoRequest()
    response = await grpc_client.GetServerInfo(request)
    assert response.active_requests >= 0
    assert response.is_paused is False
    assert response.uptime_seconds >= 0
    assert response.server_type == "vllm-grpc"
    assert response.last_receive_timestamp > 0


@pytest.mark.asyncio
async def test_generate_non_streaming(grpc_client):
    """Test the Generate RPC in non-streaming mode."""
    # Create a simple request. The input_ids are the GPT-2 encoding of the
    # prompt; any valid token ids work for this tiny random model.
    request = vllm_engine_pb2.GenerateRequest(
        request_id="test-non-streaming-1",
        tokenized=vllm_engine_pb2.TokenizedInput(
            original_text="Hello, my name is",
            input_ids=[15496, 11, 616, 1438, 318],
        ),
        sampling_params=vllm_engine_pb2.SamplingParams(
            temperature=0.0,
            max_tokens=10,
            n=1,
        ),
        stream=False,
    )

    # Collect all responses
    responses = []
    async for response in grpc_client.Generate(request):
        responses.append(response)

    # Should have exactly one response (complete)
    assert len(responses) == 1

    # Check the response
    final_response = responses[0]
    assert final_response.HasField("complete")
    complete = final_response.complete
    assert len(complete.output_ids) > 0
    assert complete.finish_reason in ["stop", "length"]
    assert complete.prompt_tokens > 0
    assert complete.completion_tokens > 0
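
# Note the API shape: even with stream=False, Generate is a server-streaming
# RPC, so the client still iterates the stream and expects exactly one
# terminal "complete" message.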


@pytest.mark.asyncio
async def test_generate_streaming(grpc_client):
    """Test the Generate RPC in streaming mode."""
    request = vllm_engine_pb2.GenerateRequest(
        request_id="test-streaming-1",
        tokenized=vllm_engine_pb2.TokenizedInput(
            original_text="The capital of France is",
            input_ids=[464, 3139, 286, 4881, 318],  # GPT-2 tokens
        ),
        sampling_params=vllm_engine_pb2.SamplingParams(
            temperature=0.0, max_tokens=10, n=1
        ),
        stream=True,
    )

    # Collect all responses
    chunks = []
    complete_response = None
    async for response in grpc_client.Generate(request):
        if response.HasField("chunk"):
            chunks.append(response.chunk)
        elif response.HasField("complete"):
            complete_response = response.complete

    # There may legitimately be zero chunks if generation finishes very fast,
    # so only the terminal complete message is asserted strictly.
    assert complete_response is not None
    assert complete_response.finish_reason in ["stop", "length"]
    assert complete_response.prompt_tokens > 0

    # Verify chunk structure
    for chunk in chunks:
        assert chunk.prompt_tokens > 0
        assert chunk.completion_tokens >= 0
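
# Wire pattern inferred from the HasField checks: each GenerateResponse holds
# a oneof payload, zero or more incremental "chunk" messages, then a single
# terminal "complete" message that closes the stream.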


@pytest.mark.asyncio
async def test_generate_with_different_sampling_params(grpc_client):
    """Test Generate with various sampling parameters."""
    # Test with temperature
    request = vllm_engine_pb2.GenerateRequest(
        request_id="test-sampling-temp",
        tokenized=vllm_engine_pb2.TokenizedInput(
            original_text="Hello",
            input_ids=[15496],
        ),
        sampling_params=vllm_engine_pb2.SamplingParams(
            temperature=0.8, top_p=0.95, max_tokens=5
        ),
        stream=False,
    )
    responses = [r async for r in grpc_client.Generate(request)]
    assert len(responses) == 1
    assert responses[0].HasField("complete")

    # Test with top_k
    request = vllm_engine_pb2.GenerateRequest(
        request_id="test-sampling-topk",
        tokenized=vllm_engine_pb2.TokenizedInput(
            original_text="Hello",
            input_ids=[15496],
        ),
        sampling_params=vllm_engine_pb2.SamplingParams(
            temperature=1.0, top_k=50, max_tokens=5
        ),
        stream=False,
    )
    responses = [r async for r in grpc_client.Generate(request)]
    assert len(responses) == 1
    assert responses[0].HasField("complete")


@pytest.mark.asyncio
async def test_generate_with_stop_strings(grpc_client):
    """Test Generate with stop strings."""
    request = vllm_engine_pb2.GenerateRequest(
        request_id="test-stop-strings",
        tokenized=vllm_engine_pb2.TokenizedInput(
            original_text="Hello",
            input_ids=[15496],
        ),
        sampling_params=vllm_engine_pb2.SamplingParams(
            temperature=0.0,
            max_tokens=20,
            stop=["\n", "END"],
        ),
        stream=False,
    )
    responses = [r async for r in grpc_client.Generate(request)]
    assert len(responses) == 1
    assert responses[0].HasField("complete")
    complete = responses[0].complete
    assert complete.finish_reason in ["stop", "length"]


@pytest.mark.asyncio
async def test_generate_multiple_requests(grpc_client):
    """Test handling multiple concurrent Generate requests."""

    async def make_request(request_id: str):
        request = vllm_engine_pb2.GenerateRequest(
            request_id=request_id,
            tokenized=vllm_engine_pb2.TokenizedInput(
                original_text="Hello",
                input_ids=[15496],
            ),
            sampling_params=vllm_engine_pb2.SamplingParams(
                temperature=0.0, max_tokens=5
            ),
            stream=False,
        )
        responses = [r async for r in grpc_client.Generate(request)]
        return responses[0]

    # Send multiple requests concurrently
    tasks = [make_request(f"test-concurrent-{i}") for i in range(3)]
    responses = await asyncio.gather(*tasks)

    # Verify all requests completed successfully
    assert len(responses) == 3
    for response in responses:
        assert response.HasField("complete")
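
# All three RPCs above share the one channel from the grpc_client fixture;
# gRPC multiplexes concurrent calls over a single HTTP/2 connection, so no
# extra connections are needed.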


@pytest.mark.asyncio
async def test_generate_with_seed(grpc_client):
    """Test Generate with a fixed seed for reproducibility."""

    def make_request(request_id: str, seed: int):
        return vllm_engine_pb2.GenerateRequest(
            request_id=request_id,
            tokenized=vllm_engine_pb2.TokenizedInput(
                original_text="The future of AI is",
                input_ids=[464, 2003, 286, 9552, 318],
            ),
            sampling_params=vllm_engine_pb2.SamplingParams(
                temperature=1.0, max_tokens=10, seed=seed
            ),
            stream=False,
        )

    # Make two requests with the same seed
    request1 = make_request("test-seed-1", 42)
    request2 = make_request("test-seed-2", 42)
    response_list1 = [r async for r in grpc_client.Generate(request1)]
    response_list2 = [r async for r in grpc_client.Generate(request2)]

    # Both should complete successfully
    assert len(response_list1) == 1
    assert len(response_list2) == 1
    assert response_list1[0].HasField("complete")
    assert response_list2[0].HasField("complete")

    # With the same seed, outputs should be identical
    output_ids1 = list(response_list1[0].complete.output_ids)
    output_ids2 = list(response_list2[0].complete.output_ids)
    assert output_ids1 == output_ids2
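
# Caveat: this reproducibility check assumes both requests run under
# comparable engine conditions; identical seeds are not guaranteed to match
# across different batch compositions, kernels, or hardware.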


@pytest.mark.asyncio
async def test_generate_error_handling(grpc_client):
    """Test error handling in Generate RPC."""
    # Request with an invalid top_p value (-33)
    request = vllm_engine_pb2.GenerateRequest(
        request_id="test-error-invalid-topp",
        sampling_params=vllm_engine_pb2.SamplingParams(
            temperature=0.0, max_tokens=10, top_p=-33
        ),
        stream=False,
    )

    # Should raise an error response
    with pytest.raises(grpc.RpcError) as exc_info:
        _ = [r async for r in grpc_client.Generate(request)]
    assert exc_info.value.code() == grpc.StatusCode.INVALID_ARGUMENT
    assert "top_p must be in (0, 1], got -33.0" in exc_info.value.details()


@pytest.mark.asyncio
async def test_abort_request(grpc_client):
    """Test the out-of-band Abort RPC."""
    request_id = "test-abort-1"

    # Start a long-running streaming generate request
    generate_request = vllm_engine_pb2.GenerateRequest(
        request_id=request_id,
        tokenized=vllm_engine_pb2.TokenizedInput(
            original_text="Hello",
            input_ids=[15496],
        ),
        sampling_params=vllm_engine_pb2.SamplingParams(
            temperature=0.0,
            min_tokens=500,
            max_tokens=500,  # Request many tokens to ensure it runs long enough
        ),
        stream=True,
    )

    # Track whether we were aborted
    was_aborted = False
    received_chunks = 0

    async def run_generate():
        nonlocal was_aborted, received_chunks
        async for response in grpc_client.Generate(generate_request):
            if response.HasField("chunk"):
                received_chunks += 1
            if response.HasField("complete"):
                complete = response.complete
                was_aborted = complete.finish_reason == "abort"
            else:
                was_aborted = False

    async def abort_after_delay():
        # Small delay to ensure generate has started
        await asyncio.sleep(0.1)
        abort_request = vllm_engine_pb2.AbortRequest(request_ids=[request_id])
        await grpc_client.Abort(abort_request)

    # Run generate and abort concurrently
    await asyncio.gather(run_generate(), abort_after_delay())

    # The request should have been aborted (received final chunk with
    # "abort" finish reason) and finished early due to the abort.
    assert was_aborted and received_chunks < 500, (
        "Request should have been aborted before generating all 500 tokens"
    )