429 lines
13 KiB
Python
429 lines
13 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
"""
|
|
End-to-end tests for the vLLM gRPC server.
|
|
"""
|
|
|
|
import asyncio
|
|
import socket
|
|
import subprocess
|
|
import sys
|
|
import time
|
|
|
|
import grpc
|
|
import pytest
|
|
import pytest_asyncio
|
|
|
|
from vllm.grpc import vllm_engine_pb2, vllm_engine_pb2_grpc
|
|
|
|
# Use a small model for fast testing
|
|
MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
|
|
|
|
|
|
def find_free_port() -> int:
    """Return a TCP port on localhost that was free at probe time.

    Note: there is an unavoidable race between closing the probe socket
    and the server binding the port. SO_REUSEADDR reduces (but cannot
    eliminate) failures caused by the port lingering in TIME_WAIT.

    Returns:
        An ephemeral port number assigned by the OS.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        # Allow the server to bind this port immediately after the
        # probe socket is closed, even if it is still in TIME_WAIT.
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        # Port 0 asks the OS to pick any free ephemeral port.
        s.bind(("", 0))
        s.listen(1)
        return s.getsockname()[1]
|
|
|
|
|
|
async def wait_for_server(port: int, timeout: float = 60.0) -> bool:
    """Wait for the gRPC server to be ready by polling health checks.

    Issues a HealthCheck RPC roughly every 0.5 seconds until the server
    reports healthy or *timeout* seconds elapse.

    Args:
        port: Localhost port the gRPC server listens on.
        timeout: Maximum number of seconds to wait.

    Returns:
        True if the server reported healthy within the timeout,
        False otherwise.
    """
    start_time = time.time()
    print("waiting for server to start...")
    while time.time() - start_time < timeout:
        try:
            # Use the channel as an async context manager so it is
            # closed even when the RPC raises (e.g. connection refused).
            # The original code leaked the channel on that path.
            async with grpc.aio.insecure_channel(f"localhost:{port}") as channel:
                stub = vllm_engine_pb2_grpc.VllmEngineStub(channel)
                request = vllm_engine_pb2.HealthCheckRequest()
                response = await stub.HealthCheck(request, timeout=5.0)
                if response.healthy:
                    print("server returned healthy=True")
                    return True
            # Server responded but is not healthy yet; back off briefly
            # instead of busy-looping on successful-but-unhealthy replies.
            await asyncio.sleep(0.5)
        except Exception:
            # Server not accepting connections yet; retry after a pause.
            await asyncio.sleep(0.5)
    return False
|
|
|
|
|
|
class GrpcServerProcess:
    """Manages a gRPC server running in a subprocess."""

    def __init__(self):
        # Handle to the server subprocess; None until start() launches it.
        self.process: subprocess.Popen | None = None
        # Port the server listens on; chosen when start() is called.
        self.port: int | None = None

    async def start(self):
        """Start the gRPC server process and wait until it is healthy.

        Raises:
            RuntimeError: if the server does not pass a health check
                within the default timeout.
        """
        self.port = find_free_port()

        # Start the server as a subprocess using the same interpreter
        # that runs the test suite.
        self.process = subprocess.Popen(
            [
                sys.executable,
                "-m",
                "vllm.entrypoints.grpc_server",
                "--model",
                MODEL_NAME,
                "--host",
                "localhost",
                "--port",
                str(self.port),
                "--max-num-batched-tokens",
                "512",
                "--disable-log-stats-server",
            ],
        )

        # Wait for server to be ready; clean up the subprocess on failure
        # so the test run does not leak an orphaned server.
        if not await wait_for_server(self.port):
            self.stop()
            raise RuntimeError("gRPC server failed to start within timeout")

    def stop(self):
        """Stop the gRPC server process. Safe to call multiple times."""
        if self.process:
            self.process.terminate()
            try:
                self.process.wait(timeout=10)
            except subprocess.TimeoutExpired:
                # Graceful termination stalled; escalate to SIGKILL.
                self.process.kill()
                self.process.wait()
            # Clear the handle so repeated stop() calls are no-ops
            # instead of signalling an already-reaped process.
            self.process = None
|
|
|
|
|
|
@pytest_asyncio.fixture(scope="module")
async def grpc_server():
    """Fixture providing a running gRPC server in a subprocess.

    Yields:
        The GrpcServerProcess wrapping the live server.
    """
    server = GrpcServerProcess()
    await server.start()

    # try/finally guarantees the subprocess is reaped even if the yield
    # raises (e.g. a dependent fixture or test setup fails).
    try:
        yield server
    finally:
        server.stop()
|
|
|
|
|
|
@pytest_asyncio.fixture
async def grpc_client(grpc_server):
    """Fixture providing a gRPC client connected to the server.

    Yields:
        A VllmEngineStub bound to the test server's channel.
    """
    channel = grpc.aio.insecure_channel(f"localhost:{grpc_server.port}")
    stub = vllm_engine_pb2_grpc.VllmEngineStub(channel)

    # try/finally guarantees the channel is closed even when the test
    # body raises; the original leaked the channel on that path.
    try:
        yield stub
    finally:
        await channel.close()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_health_check(grpc_client):
    """Test the HealthCheck RPC."""
    reply = await grpc_client.HealthCheck(vllm_engine_pb2.HealthCheckRequest())

    # A ready server reports healthy with its canonical status message.
    assert reply.healthy is True
    assert reply.message == "Health"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_model_info(grpc_client):
    """Test the GetModelInfo RPC."""
    info = await grpc_client.GetModelInfo(vllm_engine_pb2.GetModelInfoRequest())

    # The tiny test model is a text-generation model without vision.
    assert info.model_path == MODEL_NAME
    assert info.is_generation is True
    assert info.max_context_length > 0
    assert info.vocab_size > 0
    assert info.supports_vision is False
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_server_info(grpc_client):
    """Test the GetServerInfo RPC."""
    info = await grpc_client.GetServerInfo(vllm_engine_pb2.GetServerInfoRequest())

    # Sanity-check the reported server state and identity fields.
    assert info.active_requests >= 0
    assert info.is_paused is False
    assert info.uptime_seconds >= 0
    assert info.server_type == "vllm-grpc"
    assert info.last_receive_timestamp > 0
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_generate_non_streaming(grpc_client):
    """Test the Generate RPC in non-streaming mode."""
    # Build a minimal greedy-decoding request.
    request = vllm_engine_pb2.GenerateRequest(
        request_id="test-non-streaming-1",
        tokenized=vllm_engine_pb2.TokenizedInput(
            original_text="Hello, my name is",
            # GPT-2 tokens for the prompt
            input_ids=[15496, 11, 616, 1438, 318],
        ),
        sampling_params=vllm_engine_pb2.SamplingParams(
            temperature=0.0,
            max_tokens=10,
            n=1,
        ),
        stream=False,
    )

    results = [r async for r in grpc_client.Generate(request)]

    # Non-streaming mode yields exactly one "complete" message.
    assert len(results) == 1
    final = results[0]
    assert final.HasField("complete")

    # The completion carries tokens, a finish reason, and usage counts.
    complete = final.complete
    assert len(complete.output_ids) > 0
    assert complete.finish_reason in ("stop", "length")
    assert complete.prompt_tokens > 0
    assert complete.completion_tokens > 0
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_generate_streaming(grpc_client):
    """Test the Generate RPC in streaming mode."""
    request = vllm_engine_pb2.GenerateRequest(
        request_id="test-streaming-1",
        tokenized=vllm_engine_pb2.TokenizedInput(
            original_text="The capital of France is",
            input_ids=[464, 3139, 286, 4881, 318],  # GPT-2 tokens
        ),
        sampling_params=vllm_engine_pb2.SamplingParams(
            temperature=0.0, max_tokens=10, n=1
        ),
        stream=True,
    )

    # Partition the stream into incremental chunks and the final message.
    incremental = []
    final = None
    async for msg in grpc_client.Generate(request):
        if msg.HasField("chunk"):
            incremental.append(msg.chunk)
        elif msg.HasField("complete"):
            final = msg.complete

    # A very fast generation may legitimately deliver zero chunks.
    assert len(incremental) >= 0

    # The stream must always end with a complete message.
    assert final is not None
    assert final.finish_reason in ("stop", "length")
    assert final.prompt_tokens > 0

    # Every chunk carries prompt/completion token accounting.
    for piece in incremental:
        assert piece.prompt_tokens > 0
        assert piece.completion_tokens >= 0
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_generate_with_different_sampling_params(grpc_client):
    """Test Generate with various sampling parameters."""

    async def run_and_check(request):
        # Send one non-streaming request and verify a single complete reply.
        results = [r async for r in grpc_client.Generate(request)]
        assert len(results) == 1
        assert results[0].HasField("complete")

    # Test with temperature
    await run_and_check(
        vllm_engine_pb2.GenerateRequest(
            request_id="test-sampling-temp",
            tokenized=vllm_engine_pb2.TokenizedInput(
                original_text="Hello",
                input_ids=[15496],
            ),
            sampling_params=vllm_engine_pb2.SamplingParams(
                temperature=0.8, top_p=0.95, max_tokens=5
            ),
            stream=False,
        )
    )

    # Test with top_k
    await run_and_check(
        vllm_engine_pb2.GenerateRequest(
            request_id="test-sampling-topk",
            tokenized=vllm_engine_pb2.TokenizedInput(
                original_text="Hello",
                input_ids=[15496],
            ),
            sampling_params=vllm_engine_pb2.SamplingParams(
                temperature=1.0, top_k=50, max_tokens=5
            ),
            stream=False,
        )
    )
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_generate_with_stop_strings(grpc_client):
    """Test Generate with stop strings."""
    request = vllm_engine_pb2.GenerateRequest(
        request_id="test-stop-strings",
        tokenized=vllm_engine_pb2.TokenizedInput(
            original_text="Hello",
            input_ids=[15496],
        ),
        sampling_params=vllm_engine_pb2.SamplingParams(
            temperature=0.0,
            max_tokens=20,
            stop=["\n", "END"],
        ),
        stream=False,
    )

    results = [r async for r in grpc_client.Generate(request)]
    assert len(results) == 1
    assert results[0].HasField("complete")

    # Generation ends either on a stop string or on the token budget.
    assert results[0].complete.finish_reason in ("stop", "length")
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_generate_multiple_requests(grpc_client):
    """Test handling multiple concurrent Generate requests."""

    async def make_request(request_id: str):
        """Send one non-streaming request and return its single response."""
        request = vllm_engine_pb2.GenerateRequest(
            request_id=request_id,
            tokenized=vllm_engine_pb2.TokenizedInput(
                original_text="Hello",
                input_ids=[15496],
            ),
            sampling_params=vllm_engine_pb2.SamplingParams(
                temperature=0.0, max_tokens=5
            ),
            stream=False,
        )

        responses = [r async for r in grpc_client.Generate(request)]
        return responses[0]

    # Send multiple requests concurrently with distinct request ids.
    tasks = [make_request(f"test-concurrent-{i}") for i in range(3)]
    responses = await asyncio.gather(*tasks)

    # Verify all requests completed successfully.
    # (enumerate was dropped: the index was never used.)
    assert len(responses) == 3
    for response in responses:
        assert response.HasField("complete")
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_generate_with_seed(grpc_client):
    """Test Generate with a fixed seed for reproducibility."""

    def build_request(request_id: str, seed: int):
        # Identical prompt and sampling params; only id/seed vary.
        return vllm_engine_pb2.GenerateRequest(
            request_id=request_id,
            tokenized=vllm_engine_pb2.TokenizedInput(
                original_text="The future of AI is",
                input_ids=[464, 2003, 286, 9552, 318],
            ),
            sampling_params=vllm_engine_pb2.SamplingParams(
                temperature=1.0, max_tokens=10, seed=seed
            ),
            stream=False,
        )

    # Issue two requests sharing seed 42.
    first = [r async for r in grpc_client.Generate(build_request("test-seed-1", 42))]
    second = [r async for r in grpc_client.Generate(build_request("test-seed-2", 42))]

    # Both should complete successfully with a single response each.
    assert len(first) == 1
    assert len(second) == 1
    assert first[0].HasField("complete")
    assert second[0].HasField("complete")

    # Same seed + same prompt => token-for-token identical outputs.
    assert list(first[0].complete.output_ids) == list(second[0].complete.output_ids)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_generate_error_handling(grpc_client):
    """Test error handling in Generate RPC."""
    # top_p=-33 is outside the valid (0, 1] range and must be rejected.
    bad_request = vllm_engine_pb2.GenerateRequest(
        request_id="test-error-invalid-topp",
        sampling_params=vllm_engine_pb2.SamplingParams(
            temperature=0.0, max_tokens=10, top_p=-33
        ),
        stream=False,
    )

    # Consuming the stream should surface a gRPC error.
    with pytest.raises(grpc.RpcError) as exc_info:
        _ = [r async for r in grpc_client.Generate(bad_request)]

    err = exc_info.value
    assert err.code() == grpc.StatusCode.INVALID_ARGUMENT
    assert "top_p must be in (0, 1], got -33.0" in err.details()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_abort_request(grpc_client):
    """Test the out-of-band Abort RPC.

    Starts a long streaming generation, then aborts it by request id
    from a concurrent task, and checks the stream ends early with an
    "abort" finish reason.
    """
    request_id = "test-abort-1"

    # Start a long-running streaming generate request
    generate_request = vllm_engine_pb2.GenerateRequest(
        request_id=request_id,
        tokenized=vllm_engine_pb2.TokenizedInput(
            original_text="Hello",
            input_ids=[15496],
        ),
        sampling_params=vllm_engine_pb2.SamplingParams(
            temperature=0.0,
            # min_tokens prevents an early natural stop, so the stream
            # can only end short of 500 tokens via the abort.
            min_tokens=500,
            max_tokens=500,  # Request many tokens to ensure it runs long enough
        ),
        stream=True,
    )

    # Track whether we were aborted
    was_aborted = False
    received_chunks = 0

    async def run_generate():
        # Consume the stream; shared flags are updated via nonlocal so
        # the outer test body can assert on them after gather().
        nonlocal was_aborted, received_chunks
        async for response in grpc_client.Generate(generate_request):
            if response.HasField("chunk"):
                received_chunks += 1

            # Only the final "complete" message can mark the abort; any
            # other message resets the flag so a mid-stream value of
            # was_aborted never survives to the end.
            if response.HasField("complete"):
                complete = response.complete
                was_aborted = complete.finish_reason == "abort"
            else:
                was_aborted = False

    async def abort_after_delay():
        # Small delay to ensure generate has started
        await asyncio.sleep(0.1)
        # Out-of-band abort targets the request by its id.
        abort_request = vllm_engine_pb2.AbortRequest(request_ids=[request_id])
        await grpc_client.Abort(abort_request)

    # Run generate and abort concurrently
    await asyncio.gather(run_generate(), abort_after_delay())

    # The request should have been aborted (received final chunk with
    # "abort" finish reason) and finished early due to the abort.
    assert was_aborted and received_chunks < 500, (
        "Request should have been aborted before generating all 500 tokens"
    )
|