TP/quantization/weight loading refactor part 1 - Simplify parallel linear logic (#1181)

This commit is contained in:
Zhuohan Li
2023-10-02 15:36:09 -07:00
committed by GitHub
parent 84e4e37d14
commit ba0bfd40e2
42 changed files with 819 additions and 1547 deletions

View File

@@ -14,6 +14,7 @@ app = vllm.entrypoints.api_server.app
class AsyncLLMEngineWithStats(AsyncLLMEngine):
# pylint: disable=redefined-outer-name
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._num_aborts = 0

View File

@@ -24,6 +24,7 @@ def _query_server(prompt: str) -> dict:
def api_server():
script_path = Path(__file__).parent.joinpath(
"api_server_async_engine.py").absolute()
# pylint: disable=consider-using-with
uvicorn_process = subprocess.Popen([
sys.executable, "-u",
str(script_path), "--model", "facebook/opt-125m"
@@ -32,6 +33,7 @@ def api_server():
uvicorn_process.terminate()
# pylint: disable=redefined-outer-name, unused-argument
def test_api_server(api_server):
"""
Run the API server and test it.
@@ -47,6 +49,7 @@ def test_api_server(api_server):
prompts = ["Hello world"] * 1
result = None
while not result:
# pylint: disable=bare-except
try:
for result in pool.map(_query_server, prompts):
break

View File

@@ -32,12 +32,12 @@ class MockEngine:
self.request_id = None
def add_request(self, **kwargs):
del kwargs # Unused
self.add_request_calls += 1
return
def abort_request(self, request_id):
del request_id # Unused
self.abort_request_calls += 1
return
class MockAsyncLLMEngine(AsyncLLMEngine):

View File

@@ -7,22 +7,22 @@ from vllm.outputs import RequestOutput
class DummyEvent:
def __init__(self):
self._flag = False
self.flag = False
def set(self):
self._flag = True
self.flag = True
def clear(self):
self._flag = False
self.flag = False
def test_request_tracker():
tracker = RequestTracker()
tracker.new_requests_event = DummyEvent()
stream_1 = tracker.add_request("1")
assert tracker.new_requests_event._flag
assert tracker.new_requests_event.flag
new, finished = tracker.get_new_and_finished_requests()
assert not tracker.new_requests_event._flag
assert not tracker.new_requests_event.flag
assert len(new) == 1
assert new[0]["request_id"] == "1"
assert not finished
@@ -30,9 +30,9 @@ def test_request_tracker():
stream_2 = tracker.add_request("2")
stream_3 = tracker.add_request("3")
assert tracker.new_requests_event._flag
assert tracker.new_requests_event.flag
new, finished = tracker.get_new_and_finished_requests()
assert not tracker.new_requests_event._flag
assert not tracker.new_requests_event.flag
assert len(new) == 2
assert new[0]["request_id"] == "2"
assert new[1]["request_id"] == "3"
@@ -43,7 +43,7 @@ def test_request_tracker():
# request_ids must be unique
with pytest.raises(KeyError):
tracker.add_request("1")
assert not tracker.new_requests_event._flag
assert not tracker.new_requests_event.flag
tracker.abort_request("1")
new, finished = tracker.get_new_and_finished_requests()
@@ -54,7 +54,7 @@ def test_request_tracker():
stream_4 = tracker.add_request("4")
tracker.abort_request("4")
assert tracker.new_requests_event._flag
assert tracker.new_requests_event.flag
new, finished = tracker.get_new_and_finished_requests()
assert len(finished) == 1
assert "4" in finished
@@ -62,11 +62,11 @@ def test_request_tracker():
assert stream_4.finished
stream_5 = tracker.add_request("5")
assert tracker.new_requests_event._flag
assert tracker.new_requests_event.flag
tracker.process_request_output(
RequestOutput("2", "output", [], [], finished=True))
new, finished = tracker.get_new_and_finished_requests()
assert not tracker.new_requests_event._flag
assert not tracker.new_requests_event.flag
assert len(finished) == 1
assert "2" in finished
assert len(new) == 1