[V1][PP] Optimization: continue scheduling prefill chunks (#17080)

Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
Authored by Rui Qiao on 2025-04-24 05:27:08 -07:00, committed by GitHub
parent a9138e85b1
commit c0dfd97519
5 changed files with 128 additions and 74 deletions
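In short, the optimization lets the scheduler keep issuing the remaining prefill chunks of a request while earlier batches are still executing, because num_computed_tokens is advanced at schedule time rather than when the batch's output arrives (the test below asserts exactly this: "num_computed_tokens should have been updated immediately"). A minimal sketch of that idea, where ChunkedRequest and schedule_step are hypothetical illustrations and not vLLM APIs:

# A minimal sketch of schedule-time token accounting; ChunkedRequest and
# schedule_step are hypothetical illustrations, not vLLM APIs.
from dataclasses import dataclass


@dataclass
class ChunkedRequest:
    num_prompt_tokens: int
    num_computed_tokens: int = 0  # advanced at schedule time, not completion


def schedule_step(req: ChunkedRequest, chunk_size: int) -> int:
    """Schedule the next prefill chunk and return how many tokens it holds."""
    remaining = req.num_prompt_tokens - req.num_computed_tokens
    num_scheduled = min(chunk_size, remaining)
    # Key idea of the optimization: bump num_computed_tokens immediately,
    # so the next schedule call can pick up the following chunk while the
    # previous batch is still executing.
    req.num_computed_tokens += num_scheduled
    return num_scheduled


req = ChunkedRequest(num_prompt_tokens=12)
assert schedule_step(req, chunk_size=10) == 10  # Batch 1: (10, req0)
assert schedule_step(req, chunk_size=10) == 2   # Batch 2 can include (2, req0)

This mirrors the numbers in the test below: a 12-token prompt is split into a 10-token chunk and a 2-token chunk scheduled in back-to-back batches.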


@@ -1,10 +1,9 @@
 # SPDX-License-Identifier: Apache-2.0
 import copy
-import threading
 import time
 import uuid

-from concurrent.futures import Future
+from concurrent.futures import Future, ThreadPoolExecutor

 import pytest
 from transformers import AutoTokenizer
@@ -244,33 +243,33 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
                 self, kv_cache_configs: list[KVCacheConfig]) -> None:
             super().initialize_from_config(kv_cache_configs)

             # This executor actually can only run 1 batch at a time
-            self.semaphore = threading.Semaphore(1)
+            # Create a thread pool with a single worker
+            self.thread_pool = ThreadPoolExecutor(max_workers=1)

         def execute_model(
             self,
             scheduler_output,
         ) -> Future[ModelRunnerOutput]:
             """Make execute_model non-blocking."""
-            future: Future[ModelRunnerOutput] = Future()
-
-            def _thread_wrapper(scheduler_output, future):
-                with self.semaphore:
-                    output = self.collective_rpc("execute_model",
-                                                 args=(scheduler_output, ))
-                    # Make a copy because output[0] may be reused
-                    # by the next batch.
-                    output = copy.deepcopy(output[0])
-                    future.set_result(output)
+            def _execute():
+                output = self.collective_rpc("execute_model",
+                                             args=(scheduler_output, ))
+                # Make a copy because output[0] may be reused
+                # by the next batch.
+                return copy.deepcopy(output[0])

-            threading.Thread(target=_thread_wrapper,
-                             args=(scheduler_output, future)).start()
-            return future
+            # Use the thread pool instead of creating a new thread
+            return self.thread_pool.submit(_execute)

         @property
         def max_concurrent_batches(self) -> int:
             return 2

+        def shutdown(self):
+            if hasattr(self, 'thread_pool'):
+                self.thread_pool.shutdown(wait=False)
+
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
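A note on the executor change above: a single-worker ThreadPoolExecutor serializes the submitted batches exactly as the semaphore did, but submit() returns a Future immediately, so two batches can be in flight from the engine's point of view. A standalone sketch of that pattern, where slow_batch is a made-up stand-in for execute_model:

# Standalone sketch of the pattern used by the dummy executor: a
# single-worker pool runs batches one at a time, but submit() returns a
# Future immediately. slow_batch is a made-up stand-in for execute_model.
import time
from concurrent.futures import ThreadPoolExecutor

pool = ThreadPoolExecutor(max_workers=1)


def slow_batch(batch_id: int) -> int:
    time.sleep(0.1)  # simulate model execution
    return batch_id


f1 = pool.submit(slow_batch, 1)  # starts running right away
f2 = pool.submit(slow_batch, 2)  # queued behind batch 1
assert f1.result() == 1
assert f2.result() == 2
pool.shutdown(wait=False)

Compared with spawning a raw thread per batch, the pool also gives the test a clean shutdown path, which is what the new shutdown() method uses.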
@@ -299,14 +298,77 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
         # Schedule Batch 1: (10, req0)
         assert engine_core.step_with_batch_queue() is None
         assert engine_core.batch_queue.qsize() == 1
         scheduler_output = engine_core.batch_queue.queue[-1][1]
         assert scheduler_output.num_scheduled_tokens[0] == 10
+        # num_computed_tokens should have been updated immediately.
+        assert engine_core.scheduler.requests[
+            req0.request_id].num_computed_tokens == 10

         # Schedule Batch 2: (2, req0), (8, req1)
         assert engine_core.step_with_batch_queue() is None
         assert engine_core.batch_queue.qsize() == 2
         scheduler_output = engine_core.batch_queue.queue[-1][1]
         assert scheduler_output.num_scheduled_tokens[0] == 2
         assert scheduler_output.num_scheduled_tokens[1] == 8
+        # num_computed_tokens should have been updated immediately.
+        assert engine_core.scheduler.requests[0].num_computed_tokens == 12
+        assert engine_core.scheduler.requests[1].num_computed_tokens == 8

         assert engine_core.scheduler.get_num_unfinished_requests() == 2

-        # Loop through both requests.
-        while engine_core.scheduler.get_num_unfinished_requests() == 2:
-            engine_core.step_with_batch_queue()
+        # Batch queue is full. Finish Batch 1.
+        engine_core.step_with_batch_queue()

-        # Reaching here when got the result of the first request.
-        while engine_core.scheduler.get_num_unfinished_requests() == 1:
-            engine_core.step_with_batch_queue()
+        # Schedule Batch 3: (4, req1). Note that req0 cannot be scheduled
+        # because it is in the decoding stage now.
+        engine_core.step_with_batch_queue()
+        assert engine_core.batch_queue.qsize() == 2
+        scheduler_output = engine_core.batch_queue.queue[-1][1]
+        assert scheduler_output.num_scheduled_tokens[1] == 4
+
+        # Batch queue is full. Finish Batch 2. Get first token of req0.
+        output = engine_core.step_with_batch_queue()
+        assert output is not None
+        assert len(output.outputs) == 1
+        assert engine_core.scheduler.requests[req0.request_id].num_tokens == 13
+
+        # Schedule Batch 4: (1, req0).
+        engine_core.step_with_batch_queue()
+        assert engine_core.batch_queue.qsize() == 2
+        scheduler_output = engine_core.batch_queue.queue[-1][1]
+        assert scheduler_output.num_scheduled_tokens[0] == 1
+
+        # Batch queue is full. Finish Batch 3. Get first token of req1.
+        output = engine_core.step_with_batch_queue()
+        assert output is not None
+        assert len(output.outputs) == 1
+        assert engine_core.scheduler.requests[req1.request_id].num_tokens == 13
+
+        # Schedule Batch 5: (1, req1).
+        engine_core.step_with_batch_queue()
+        assert engine_core.batch_queue.qsize() == 2
+        scheduler_output = engine_core.batch_queue.queue[-1][1]
+        assert scheduler_output.num_scheduled_tokens[1] == 1
+
+        # Loop until req0 is finished.
+        step = 0
+        req_id = 0
+        expected_num_tokens = [
+            engine_core.scheduler.requests[0].num_tokens + 1,
+            engine_core.scheduler.requests[1].num_tokens + 1,
+        ]
+        while engine_core.scheduler.get_num_unfinished_requests() == 2:
+            output = engine_core.step_with_batch_queue()
+            if step % 2 == 0:
+                # Even steps consume an output.
+                assert output is not None
+                assert len(output.outputs) == 1
+                if req_id in engine_core.scheduler.requests:
+                    assert engine_core.scheduler.requests[
+                        req_id].num_tokens == expected_num_tokens[req_id]
+                expected_num_tokens[req_id] += 1
+                req_id = (req_id + 1) % 2
+            else:
+                # Odd steps schedule a new batch.
+                assert output is None
+            step += 1
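The final loop encodes the steady state of step_with_batch_queue with a queue depth of 2: when the queue has room, a step schedules a new batch and returns None; when the queue is full, a step blocks on the oldest batch and returns its output, so the two behaviors strictly alternate. A toy model of that alternation; the queue discipline here is a simplified stand-in for step_with_batch_queue, not the engine's real logic:

# Toy model of the schedule/consume alternation the loop above settles
# into; a simplified stand-in for step_with_batch_queue.
from collections import deque
from typing import Optional

MAX_IN_FLIGHT = 2  # mirrors max_concurrent_batches == 2 in the test
batch_queue: deque = deque()
next_batch = 0


def step() -> Optional[int]:
    """Schedule a batch if there is room, else consume the oldest one."""
    global next_batch
    if len(batch_queue) < MAX_IN_FLIGHT:
        batch_queue.append(next_batch)  # "schedule" -> returns None
        next_batch += 1
        return None
    return batch_queue.popleft()  # "consume" -> returns an output


assert step() is None  # schedule batch 0
assert step() is None  # schedule batch 1; queue now full
assert step() == 0     # consume batch 0
assert step() is None  # room again: schedule batch 2
assert step() == 1     # consume batch 1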