TP/quantization/weight loading refactor part 1 - Simplify parallel linear logic (#1181)

This commit is contained in:
Zhuohan Li
2023-10-02 15:36:09 -07:00
committed by GitHub
parent 84e4e37d14
commit ba0bfd40e2
42 changed files with 819 additions and 1547 deletions

View File

@@ -7,22 +7,22 @@ from vllm.outputs import RequestOutput
class DummyEvent:
def __init__(self):
self._flag = False
self.flag = False
def set(self):
self._flag = True
self.flag = True
def clear(self):
self._flag = False
self.flag = False
def test_request_tracker():
tracker = RequestTracker()
tracker.new_requests_event = DummyEvent()
stream_1 = tracker.add_request("1")
assert tracker.new_requests_event._flag
assert tracker.new_requests_event.flag
new, finished = tracker.get_new_and_finished_requests()
assert not tracker.new_requests_event._flag
assert not tracker.new_requests_event.flag
assert len(new) == 1
assert new[0]["request_id"] == "1"
assert not finished
@@ -30,9 +30,9 @@ def test_request_tracker():
stream_2 = tracker.add_request("2")
stream_3 = tracker.add_request("3")
assert tracker.new_requests_event._flag
assert tracker.new_requests_event.flag
new, finished = tracker.get_new_and_finished_requests()
assert not tracker.new_requests_event._flag
assert not tracker.new_requests_event.flag
assert len(new) == 2
assert new[0]["request_id"] == "2"
assert new[1]["request_id"] == "3"
@@ -43,7 +43,7 @@ def test_request_tracker():
# request_ids must be unique
with pytest.raises(KeyError):
tracker.add_request("1")
assert not tracker.new_requests_event._flag
assert not tracker.new_requests_event.flag
tracker.abort_request("1")
new, finished = tracker.get_new_and_finished_requests()
@@ -54,7 +54,7 @@ def test_request_tracker():
stream_4 = tracker.add_request("4")
tracker.abort_request("4")
assert tracker.new_requests_event._flag
assert tracker.new_requests_event.flag
new, finished = tracker.get_new_and_finished_requests()
assert len(finished) == 1
assert "4" in finished
@@ -62,11 +62,11 @@ def test_request_tracker():
assert stream_4.finished
stream_5 = tracker.add_request("5")
assert tracker.new_requests_event._flag
assert tracker.new_requests_event.flag
tracker.process_request_output(
RequestOutput("2", "output", [], [], finished=True))
new, finished = tracker.get_new_and_finished_requests()
assert not tracker.new_requests_event._flag
assert not tracker.new_requests_event.flag
assert len(finished) == 1
assert "2" in finished
assert len(new) == 1