Remove AsyncLLMEngine busy loop, shield background task (#1059)
This commit is contained in:
80
tests/async_engine/test_async_llm_engine.py
Normal file
80
tests/async_engine/test_async_llm_engine.py
Normal file
@@ -0,0 +1,80 @@
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestOutput:
|
||||
request_id: int
|
||||
finished: bool = False
|
||||
|
||||
|
||||
class MockEngine:
|
||||
|
||||
def __init__(self):
|
||||
self.step_calls = 0
|
||||
self.add_request_calls = 0
|
||||
self.abort_request_calls = 0
|
||||
self.request_id = None
|
||||
|
||||
async def step_async(self):
|
||||
self.step_calls += 1
|
||||
return [RequestOutput(
|
||||
request_id=self.request_id)] if self.request_id else []
|
||||
|
||||
def generate(self, request_id):
|
||||
self.request_id = request_id
|
||||
|
||||
def stop_generating(self):
|
||||
self.request_id = None
|
||||
|
||||
def add_request(self, **kwargs):
|
||||
self.add_request_calls += 1
|
||||
return
|
||||
|
||||
def abort_request(self, request_id):
|
||||
self.abort_request_calls += 1
|
||||
return
|
||||
|
||||
|
||||
class MockAsyncLLMEngine(AsyncLLMEngine):
|
||||
|
||||
def _init_engine(self, *args, **kwargs):
|
||||
return MockEngine()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_requests_event():
|
||||
engine = MockAsyncLLMEngine(worker_use_ray=False, engine_use_ray=False)
|
||||
engine.start_background_loop()
|
||||
await asyncio.sleep(0.01)
|
||||
assert engine.engine.step_calls == 0
|
||||
|
||||
await engine.add_request("1", "", None)
|
||||
await asyncio.sleep(0.01)
|
||||
assert engine.engine.add_request_calls == 1
|
||||
assert engine.engine.step_calls == 1
|
||||
|
||||
await engine.add_request("2", "", None)
|
||||
engine.engine.generate("2")
|
||||
await asyncio.sleep(0)
|
||||
assert engine.engine.add_request_calls == 2
|
||||
assert engine.engine.step_calls == 2
|
||||
await asyncio.sleep(0)
|
||||
assert engine.engine.step_calls == 3
|
||||
engine.engine.stop_generating()
|
||||
await asyncio.sleep(0)
|
||||
assert engine.engine.step_calls == 4
|
||||
await asyncio.sleep(0)
|
||||
assert engine.engine.step_calls == 4
|
||||
|
||||
await engine.add_request("3", "", None)
|
||||
await asyncio.sleep(0.01)
|
||||
assert engine.engine.add_request_calls == 3
|
||||
assert engine.engine.step_calls == 5
|
||||
await asyncio.sleep(0.01)
|
||||
assert engine.engine.add_request_calls == 3
|
||||
assert engine.engine.step_calls == 5
|
||||
@@ -4,10 +4,25 @@ from vllm.engine.async_llm_engine import RequestTracker
|
||||
from vllm.outputs import RequestOutput
|
||||
|
||||
|
||||
class DummyEvent:
|
||||
|
||||
def __init__(self):
|
||||
self._flag = False
|
||||
|
||||
def set(self):
|
||||
self._flag = True
|
||||
|
||||
def clear(self):
|
||||
self._flag = False
|
||||
|
||||
|
||||
def test_request_tracker():
|
||||
tracker = RequestTracker()
|
||||
tracker.new_requests_event = DummyEvent()
|
||||
stream_1 = tracker.add_request("1")
|
||||
assert tracker.new_requests_event._flag
|
||||
new, finished = tracker.get_new_and_finished_requests()
|
||||
assert not tracker.new_requests_event._flag
|
||||
assert len(new) == 1
|
||||
assert new[0]["request_id"] == "1"
|
||||
assert not finished
|
||||
@@ -15,7 +30,9 @@ def test_request_tracker():
|
||||
|
||||
stream_2 = tracker.add_request("2")
|
||||
stream_3 = tracker.add_request("3")
|
||||
assert tracker.new_requests_event._flag
|
||||
new, finished = tracker.get_new_and_finished_requests()
|
||||
assert not tracker.new_requests_event._flag
|
||||
assert len(new) == 2
|
||||
assert new[0]["request_id"] == "2"
|
||||
assert new[1]["request_id"] == "3"
|
||||
@@ -26,6 +43,7 @@ def test_request_tracker():
|
||||
# request_ids must be unique
|
||||
with pytest.raises(KeyError):
|
||||
tracker.add_request("1")
|
||||
assert not tracker.new_requests_event._flag
|
||||
|
||||
tracker.abort_request("1")
|
||||
new, finished = tracker.get_new_and_finished_requests()
|
||||
@@ -36,6 +54,7 @@ def test_request_tracker():
|
||||
|
||||
stream_4 = tracker.add_request("4")
|
||||
tracker.abort_request("4")
|
||||
assert tracker.new_requests_event._flag
|
||||
new, finished = tracker.get_new_and_finished_requests()
|
||||
assert len(finished) == 1
|
||||
assert "4" in finished
|
||||
@@ -43,9 +62,11 @@ def test_request_tracker():
|
||||
assert stream_4.finished
|
||||
|
||||
stream_5 = tracker.add_request("5")
|
||||
assert tracker.new_requests_event._flag
|
||||
tracker.process_request_output(
|
||||
RequestOutput("2", "output", [], [], finished=True))
|
||||
new, finished = tracker.get_new_and_finished_requests()
|
||||
assert not tracker.new_requests_event._flag
|
||||
assert len(finished) == 1
|
||||
assert "2" in finished
|
||||
assert len(new) == 1
|
||||
|
||||
Reference in New Issue
Block a user