Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk> Signed-off-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
173 lines
5.5 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
import asyncio
|
|
from collections.abc import AsyncGenerator
|
|
from unittest.mock import AsyncMock, MagicMock
|
|
|
|
import pytest
|
|
|
|
from vllm.engine.protocol import StreamingInput
|
|
from vllm.outputs import RequestOutput
|
|
from vllm.sampling_params import RequestOutputKind, SamplingParams
|
|
from vllm.v1.engine.async_llm import AsyncLLM
|
|
from vllm.v1.engine.output_processor import RequestOutputCollector
|
|
|
|
|
|
@pytest.fixture
def mock_async_llm():
    """Build a mock AsyncLLM carrying just enough state for the real
    ``generate`` method to run against it, without initializing the engine."""
    engine = MagicMock(spec=AsyncLLM)

    # Minimal config surface consulted by generate().
    engine.vllm_config = MagicMock()
    engine.vllm_config.cache_config.kv_sharing_fast_prefill = False
    engine.model_config = MagicMock()
    engine.model_config.max_model_len = 2048

    # Flags and synchronization primitives generate() reads.
    engine.log_requests = False
    engine.errored = False
    engine._paused = False
    engine._pause_cond = asyncio.Condition()

    # Stub out the side-effecting collaborators.
    engine._run_output_handler = MagicMock()
    engine.abort = AsyncMock()

    # Bind the genuine AsyncLLM.generate onto the mock so the code under
    # test is the real implementation, not a mock.
    engine.generate = AsyncLLM.generate.__get__(engine, AsyncLLM)

    return engine
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_generate_normal_flow(mock_async_llm):
    """Test normal generation flow with streaming requests.

    Feeds two RequestOutputs through a RequestOutputCollector and checks
    that the real ``AsyncLLM.generate`` yields both, with only the second
    marked finished.
    """
    request_id = "test_request"
    prompt = "Tell me about Paris"
    sampling_params = SamplingParams(max_tokens=10)

    # Create a mock queue with outputs.
    queue = RequestOutputCollector(RequestOutputKind.FINAL_ONLY, request_id)
    output1 = RequestOutput(
        request_id=request_id,
        prompt="Tell me about Paris",
        prompt_token_ids=[1, 2, 3],
        prompt_logprobs=None,
        outputs=[],
        finished=False,
    )
    output2 = RequestOutput(
        request_id=request_id,
        prompt="Tell me about Paris",
        prompt_token_ids=[1, 2, 3],
        prompt_logprobs=None,
        outputs=[],
        finished=True,
    )

    # Feed outputs to queue as they're consumed so the collector does not
    # aggregate them into a single item.
    async def feed_outputs():
        queue.put(output1)
        # Yield long enough for the consumer to drain output1 before the
        # second put. 0.1s is ample for a pure-asyncio handoff; the
        # previous 1s sleep added a full second to every suite run.
        await asyncio.sleep(0.1)
        queue.put(output2)

    asyncio.create_task(feed_outputs())  # noqa

    # Mock add_request to return the queue.
    async def mock_add_request(*args, **kwargs):
        return queue

    mock_async_llm.add_request = mock_add_request

    # Collect outputs from generate.
    outputs = []
    async for output in mock_async_llm.generate(
        prompt=prompt,
        sampling_params=sampling_params,
        request_id=request_id,
    ):
        outputs.append(output)

    assert len(outputs) == 2
    assert outputs[0].finished is False
    assert outputs[1].finished is True
|
|
|
|
|
|
def make_output(request_id: str, finished: bool) -> RequestOutput:
    """Build a minimal RequestOutput with a fixed dummy prompt for tests."""
    fields = dict(
        request_id=request_id,
        prompt="test",
        prompt_token_ids=[1, 2, 3],
        prompt_logprobs=None,
        outputs=[],
        finished=finished,
    )
    return RequestOutput(**fields)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_generate_with_async_generator():
    """Test generate with an async input generator.

    With the new streaming input API, completion is signaled by finishing
    the input generator (not via a resumable flag). Each input chunk
    produces intermediate outputs, and the final output has finished=True.
    """
    request_id = "test"
    sampling_params = SamplingParams(max_tokens=10)

    # Hand-built AsyncLLM mock with only the attributes generate() touches.
    llm = MagicMock(spec=AsyncLLM)
    llm.vllm_config = MagicMock()
    llm.vllm_config.cache_config.kv_sharing_fast_prefill = False
    llm.model_config = MagicMock()
    llm.model_config.max_model_len = 2048
    llm.log_requests = False
    llm.errored = False
    llm._pause_cond = asyncio.Condition()
    llm._paused = False
    llm._run_output_handler = MagicMock()
    llm.abort = AsyncMock()

    # Bind the real generate method so the production path is exercised.
    llm.generate = AsyncLLM.generate.__get__(llm, AsyncLLM)

    # Prompts observed by the fake engine, in arrival order.
    inputs_received = []
    queue = RequestOutputCollector(RequestOutputKind.DELTA, request_id)

    async def mock_add_request(req_id, prompt, params, *args, **kwargs):
        # Streaming-input prompts arrive as an AsyncGenerator: consume it in
        # the background, emitting one intermediate output per chunk and a
        # finished output once the stream is exhausted.
        if isinstance(prompt, AsyncGenerator):

            async def handle_stream():
                async for chunk in prompt:
                    inputs_received.append(chunk.prompt)
                    queue.put(make_output(req_id, finished=False))
                    await asyncio.sleep(0.01)
                queue.put(make_output(req_id, finished=True))

            asyncio.create_task(handle_stream())
        return queue

    llm.add_request = mock_add_request

    async def input_generator() -> AsyncGenerator[StreamingInput, None]:
        yield StreamingInput(prompt="Hello", sampling_params=sampling_params)
        yield StreamingInput(prompt=" world", sampling_params=sampling_params)

    outputs = [
        out
        async for out in llm.generate(input_generator(), sampling_params, request_id)
    ]

    # Two intermediate outputs + one final output.
    assert len(outputs) == 3
    assert [o.finished for o in outputs] == [False, False, True]
    # Both inputs were processed, in order.
    assert inputs_received == ["Hello", " world"]
|