[V1][PP] Optimization: continue scheduling prefill chunks (#17080)
Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
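The test changes below encode the scheduling behavior this commit targets: with a per-batch budget of 10 tokens and two 12-token prompts, the remaining prefill chunk of a request can be scheduled in the very next batch, because (as the updated test asserts) num_computed_tokens is bumped as soon as a chunk is scheduled rather than when the batch finishes. A rough, self-contained sketch of that bookkeeping; the budget and prompt lengths mirror the test, but the code itself is illustrative and is not vLLM's scheduler:

# Toy bookkeeping for chunked prefill where num_computed_tokens is updated
# as soon as a chunk is scheduled. The 10-token budget and 12-token prompts
# mirror the test below; this is illustrative only, not vLLM's scheduler.
TOKEN_BUDGET = 10
prompt_lens = {"req0": 12, "req1": 12}
num_computed_tokens = {"req0": 0, "req1": 0}


def schedule_one_batch() -> dict[str, int]:
    budget = TOKEN_BUDGET
    batch: dict[str, int] = {}
    for req_id, prompt_len in prompt_lens.items():
        chunk = min(prompt_len - num_computed_tokens[req_id], budget)
        if chunk > 0:
            batch[req_id] = chunk
            # Updated at schedule time, so the next batch can continue this
            # prompt even though the current batch has not executed yet.
            num_computed_tokens[req_id] += chunk
            budget -= chunk
    return batch


assert schedule_one_batch() == {"req0": 10}            # Batch 1
assert schedule_one_batch() == {"req0": 2, "req1": 8}  # Batch 2
assert schedule_one_batch() == {"req1": 4}             # Batch 3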
@@ -1,10 +1,9 @@
# SPDX-License-Identifier: Apache-2.0

import copy
import threading
import time
import uuid
from concurrent.futures import Future
from concurrent.futures import Future, ThreadPoolExecutor

import pytest
from transformers import AutoTokenizer
@@ -244,33 +243,33 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
                self, kv_cache_configs: list[KVCacheConfig]) -> None:
            super().initialize_from_config(kv_cache_configs)

            # This executor actually can only run 1 batch at a time
            self.semaphore = threading.Semaphore(1)
            # Create a thread pool with a single worker
            self.thread_pool = ThreadPoolExecutor(max_workers=1)

        def execute_model(
            self,
            scheduler_output,
        ) -> Future[ModelRunnerOutput]:
            """Make execute_model non-blocking."""
            future: Future[ModelRunnerOutput] = Future()

            def _thread_wrapper(scheduler_output, future):
                with self.semaphore:
                    output = self.collective_rpc("execute_model",
                                                 args=(scheduler_output, ))
                    # Make a copy because output[0] may be reused
                    # by the next batch.
                    output = copy.deepcopy(output[0])
                    future.set_result(output)
            def _execute():
                output = self.collective_rpc("execute_model",
                                             args=(scheduler_output, ))
                # Make a copy because output[0] may be reused
                # by the next batch.
                return copy.deepcopy(output[0])

            threading.Thread(target=_thread_wrapper,
                             args=(scheduler_output, future)).start()
            return future
            # Use the thread pool instead of creating a new thread
            return self.thread_pool.submit(_execute)

        @property
        def max_concurrent_batches(self) -> int:
            return 2

        def shutdown(self):
            if hasattr(self, 'thread_pool'):
                self.thread_pool.shutdown(wait=False)

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")
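The executor change above replaces the hand-rolled thread-plus-semaphore wrapper with a single-worker ThreadPoolExecutor: batches still run one at a time, but execute_model hands back a Future immediately, so the engine can keep a second batch queued. A minimal standard-library sketch of the same pattern; run_batch and the sleep are stand-ins, not vLLM code:

import time
from concurrent.futures import Future, ThreadPoolExecutor

# A single worker means at most one batch executes at a time, yet callers
# are never blocked: they get a Future back right away.
pool = ThreadPoolExecutor(max_workers=1)


def run_batch(batch_id: int) -> str:
    # Stand-in for collective_rpc("execute_model", ...); illustrative only.
    time.sleep(0.1)
    return f"output-of-batch-{batch_id}"


def execute_model_nonblocking(batch_id: int) -> Future[str]:
    # submit() returns immediately; the pool runs queued batches in order.
    return pool.submit(run_batch, batch_id)


if __name__ == "__main__":
    f1 = execute_model_nonblocking(1)
    f2 = execute_model_nonblocking(2)  # queued behind batch 1
    print(f1.result(), f2.result())
    pool.shutdown(wait=False)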
@@ -299,14 +298,77 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
        # Schedule Batch 1: (10, req0)
        assert engine_core.step_with_batch_queue() is None
        assert engine_core.batch_queue.qsize() == 1
        scheduler_output = engine_core.batch_queue.queue[-1][1]
        assert scheduler_output.num_scheduled_tokens[0] == 10
        # num_computed_tokens should have been updated immediately.
        assert engine_core.scheduler.requests[
            req0.request_id].num_computed_tokens == 10

        # Schedule Batch 2: (2, req0), (8, req1)
        assert engine_core.step_with_batch_queue() is None
        assert engine_core.batch_queue.qsize() == 2
        scheduler_output = engine_core.batch_queue.queue[-1][1]
        assert scheduler_output.num_scheduled_tokens[0] == 2
        assert scheduler_output.num_scheduled_tokens[1] == 8
        # num_computed_tokens should have been updated immediately.
        assert engine_core.scheduler.requests[0].num_computed_tokens == 12
        assert engine_core.scheduler.requests[1].num_computed_tokens == 8

        assert engine_core.scheduler.get_num_unfinished_requests() == 2

        # Loop through both requests.
        while engine_core.scheduler.get_num_unfinished_requests() == 2:
            engine_core.step_with_batch_queue()
        # Batch queue is full. Finish Batch 1.
        engine_core.step_with_batch_queue()

        # Reaching here when got the result of the first request.
        while engine_core.scheduler.get_num_unfinished_requests() == 1:
            engine_core.step_with_batch_queue()
        # Schedule Batch 3: (4, req1). Note that req0 cannot be scheduled
        # because it is in the decoding stage now.
        engine_core.step_with_batch_queue()
        assert engine_core.batch_queue.qsize() == 2
        scheduler_output = engine_core.batch_queue.queue[-1][1]
        assert scheduler_output.num_scheduled_tokens[1] == 4

        # Batch queue is full. Finish Batch 2. Get first token of req0.
        output = engine_core.step_with_batch_queue()
        assert output is not None
        assert len(output.outputs) == 1
        assert engine_core.scheduler.requests[req0.request_id].num_tokens == 13

        # Schedule Batch 4: (1, req0).
        engine_core.step_with_batch_queue()
        assert engine_core.batch_queue.qsize() == 2
        scheduler_output = engine_core.batch_queue.queue[-1][1]
        assert scheduler_output.num_scheduled_tokens[0] == 1

        # Batch queue is full. Finish Batch 3. Get first token of req1.
        output = engine_core.step_with_batch_queue()
        assert output is not None
        assert len(output.outputs) == 1
        assert engine_core.scheduler.requests[req1.request_id].num_tokens == 13

        # Schedule Batch 5: (1, req1).
        engine_core.step_with_batch_queue()
        assert engine_core.batch_queue.qsize() == 2
        scheduler_output = engine_core.batch_queue.queue[-1][1]
        assert scheduler_output.num_scheduled_tokens[1] == 1

        # Loop until req0 is finished.
        step = 0
        req_id = 0
        expected_num_tokens = [
            engine_core.scheduler.requests[0].num_tokens + 1,
            engine_core.scheduler.requests[1].num_tokens + 1,
        ]
        while engine_core.scheduler.get_num_unfinished_requests() == 2:
            output = engine_core.step_with_batch_queue()
            if step % 2 == 0:
                # Even steps consumes an output.
                assert output is not None
                assert len(output.outputs) == 1
                if req_id in engine_core.scheduler.requests:
                    assert engine_core.scheduler.requests[
                        req_id].num_tokens == expected_num_tokens[req_id]
                expected_num_tokens[req_id] += 1
                req_id = (req_id + 1) % 2
            else:
                # Odd steps schedules a new batch.
                assert output is None
            step += 1
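Once the batch queue is full (capacity 2, matching max_concurrent_batches), the test expects step_with_batch_queue() to alternate: one call schedules a new batch and returns None, the next drains the oldest finished batch and returns its output. A toy model of that alternation, with names chosen here for illustration only; this is not vLLM's implementation:

from collections import deque
from typing import Optional


class ToyBatchQueueEngine:
    """Illustrative only: alternates scheduling and draining batches."""

    def __init__(self, max_concurrent_batches: int = 2) -> None:
        self.batch_queue: deque[int] = deque()
        self.max_concurrent_batches = max_concurrent_batches
        self.next_batch_id = 0

    def step_with_batch_queue(self) -> Optional[str]:
        # Schedule a new batch whenever there is room in the queue.
        if len(self.batch_queue) < self.max_concurrent_batches:
            self.batch_queue.append(self.next_batch_id)
            self.next_batch_id += 1
            return None  # a scheduling step produces no output
        # Queue is full: drain the oldest batch and return its output.
        finished = self.batch_queue.popleft()
        return f"output-of-batch-{finished}"


engine = ToyBatchQueueEngine()
outputs = [engine.step_with_batch_queue() for _ in range(6)]
# The first two steps only schedule; after that, steps alternate between
# scheduling and returning an output, which is the alternation the test's
# final loop asserts.
assert outputs[:3] == [None, None, "output-of-batch-0"]
assert outputs[3:] == [None, "output-of-batch-1", None]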