[BugFix] Async scheduling and PP compatibility with DP (#23770)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -81,12 +81,10 @@ class Executor(ExecutorBase):
|
||||
pass
|
||||
|
||||
# Returns the result of the "determine_available_memory" collective RPC;
# the annotation and trailing comment describe it as a list of ints in bytes.
# NOTE(review): this span comes from a diff rendering whose +/- markers were
# lost. The `output = ...` / `return output` pair and the single
# `return self.collective_rpc(...)` look like the before and after forms of
# the same body -- confirm against the repository before relying on either.
def determine_available_memory(self) -> list[int]: # in bytes
|
||||
output = self.collective_rpc("determine_available_memory")
|
||||
return output
|
||||
return self.collective_rpc("determine_available_memory")
|
||||
|
||||
# Returns the result of the "get_kv_cache_spec" collective RPC (note the RPC
# name is singular while this method's name is plural); annotated as a list
# of dicts mapping str to KVCacheSpec.
# NOTE(review): as with the neighbouring hunk lines, the two-statement form
# and the single `return` appear to be the removed and added versions of the
# same body from a diff whose markers were lost -- confirm against the repo.
def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]:
|
||||
output = self.collective_rpc("get_kv_cache_spec")
|
||||
return output
|
||||
return self.collective_rpc("get_kv_cache_spec")
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
@@ -96,6 +94,9 @@ class Executor(ExecutorBase):
|
||||
args=(scheduler_output, ))
|
||||
return output[0]
|
||||
|
||||
# Issues the "execute_dummy_batch" collective RPC and discards whatever it
# returns; the method itself returns None.
def execute_dummy_batch(self) -> None:
|
||||
self.collective_rpc("execute_dummy_batch")
|
||||
|
||||
# Issues the "take_draft_token_ids" collective RPC and returns only the first
# element of its result (annotated as Optional[DraftTokenIds]).
# NOTE(review): presumably element 0 is the first worker's reply -- confirm
# against collective_rpc's contract.
def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
|
||||
output = self.collective_rpc("take_draft_token_ids")
|
||||
return output[0]
|
||||
|
||||
@@ -191,6 +191,10 @@ class MultiprocExecutor(Executor):
|
||||
outputs, self.output_rank)
|
||||
return self.kv_output_aggregator.aggregate(outputs, self.output_rank)
|
||||
|
||||
# MultiprocExecutor variant: issues the "execute_dummy_batch" collective RPC
# with unique_reply_rank set to self.output_rank and discards the result.
# NOTE(review): presumably unique_reply_rank limits the reply to that single
# rank -- confirm against collective_rpc's definition, which is not visible
# here.
def execute_dummy_batch(self) -> None:
|
||||
self.collective_rpc("execute_dummy_batch",
|
||||
unique_reply_rank=self.output_rank)
|
||||
|
||||
def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
|
||||
# OPTIMIZATION: Get output only from a single worker (output_rank)
|
||||
outputs = self.collective_rpc("take_draft_token_ids",
|
||||
@@ -242,12 +246,17 @@ class MultiprocExecutor(Executor):
|
||||
dequeue_timeout = None if deadline is None else (
|
||||
deadline - time.monotonic())
|
||||
|
||||
if non_block:
|
||||
if self.io_thread_pool is not None:
|
||||
# We must consume worker_response_mq from a single thread.
|
||||
result = self.io_thread_pool.submit( # type: ignore
|
||||
get_response, w, dequeue_timeout, self.shutdown_event)
|
||||
else:
|
||||
if not non_block:
|
||||
result = result.result()
|
||||
elif not non_block:
|
||||
result = get_response(w, dequeue_timeout)
|
||||
|
||||
else:
|
||||
raise RuntimeError("non_block can only be used when"
|
||||
" max_concurrent_batches > 1")
|
||||
responses.append(result)
|
||||
|
||||
return responses
|
||||
|
||||
Reference in New Issue
Block a user