[Core] Pipeline Parallel Support (#4412)

Signed-off-by: Muralidhar Andoorveedu <muralidhar.andoorveedu@centml.ai>
This commit is contained in:
Murali Andoorveedu
2024-07-02 10:58:08 -07:00
committed by GitHub
parent 15aba081f3
commit c5832d2ae9
82 changed files with 1096 additions and 400 deletions

View File

@@ -211,7 +211,8 @@ class _AsyncLLMEngine(LLMEngine):
"""Extension of LLMEngine to add async methods."""
async def step_async(
self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
self, virtual_engine: int
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Performs one decoding iteration and returns newly generated results.
The workers are ran asynchronously if possible.
@@ -221,7 +222,8 @@ class _AsyncLLMEngine(LLMEngine):
and updates the scheduler with the model outputs. Finally, it decodes
the sequences and returns the newly generated results.
"""
seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
seq_group_metadata_list, scheduler_outputs = self.scheduler[
virtual_engine].schedule()
if not scheduler_outputs.is_empty():
# Execute the model.
@@ -230,6 +232,7 @@ class _AsyncLLMEngine(LLMEngine):
blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
blocks_to_copy=scheduler_outputs.blocks_to_copy,
virtual_engine=virtual_engine,
num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
running_queue_size=scheduler_outputs.running_queue_size,
)
@@ -248,16 +251,12 @@ class _AsyncLLMEngine(LLMEngine):
# Tracing
self.do_tracing(scheduler_outputs)
if not request_outputs:
# Stop the execute model loop in parallel workers until there are
# more requests to process. This avoids waiting indefinitely in
# torch.distributed ops which may otherwise timeout, and unblocks
# the RPC thread in the workers so that they can process any other
# queued control plane messages, such as add/remove lora adapters.
await self.model_executor.stop_remote_worker_execution_loop_async()
return request_outputs
async def stop_remote_worker_execution_loop_async(self) -> None:
"""Stop the remote worker execution loop."""
await self.model_executor.stop_remote_worker_execution_loop_async()
async def process_model_inputs_async(
self,
request_id: str,
@@ -491,7 +490,8 @@ class AsyncLLMEngine:
# order of the arguments.
cache_config = kwargs["cache_config"]
parallel_config = kwargs["parallel_config"]
if parallel_config.tensor_parallel_size == 1:
if (parallel_config.tensor_parallel_size == 1
and parallel_config.pipeline_parallel_size == 1):
num_gpus = cache_config.gpu_memory_utilization
else:
num_gpus = 1
@@ -499,7 +499,7 @@ class AsyncLLMEngine:
self._engine_class).remote
return engine_class(*args, **kwargs)
async def engine_step(self) -> bool:
async def engine_step(self, virtual_engine: int) -> bool:
"""Kick the engine to process the waiting requests.
Returns True if there are in-progress requests."""
@@ -530,7 +530,7 @@ class AsyncLLMEngine:
if self.engine_use_ray:
request_outputs = await self.engine.step.remote() # type: ignore
else:
request_outputs = await self.engine.step_async()
request_outputs = await self.engine.step_async(virtual_engine)
# Put the outputs into the corresponding streams.
for request_output in request_outputs:
@@ -546,18 +546,65 @@ class AsyncLLMEngine:
self.engine.abort_request(request_ids)
async def run_engine_loop(self):
has_requests_in_progress = False
if self.engine_use_ray:
pipeline_parallel_size = 1 # type: ignore
else:
pipeline_parallel_size = \
self.engine.parallel_config.pipeline_parallel_size
has_requests_in_progress = [False] * pipeline_parallel_size
while True:
if not has_requests_in_progress:
if not any(has_requests_in_progress):
logger.debug("Waiting for new requests...")
# Stop the execute model loop in parallel workers until there
# are more requests to process. This avoids waiting
# indefinitely in torch.distributed ops which may otherwise
# timeout, and unblocks the RPC thread in the workers so that
# they can process any other queued control plane messages,
# such as add/remove lora adapters.
if self.engine_use_ray:
await (self.engine.stop_remote_worker_execution_loop.
remote() # type: ignore
)
else:
await self.engine.stop_remote_worker_execution_loop_async()
await self._request_tracker.wait_for_new_requests()
logger.debug("Got new requests!")
requests_in_progress = [
asyncio.create_task(self.engine_step(ve))
for ve in range(pipeline_parallel_size)
]
has_requests_in_progress = [True] * pipeline_parallel_size
# Abort if iteration takes too long due to unrecoverable errors
# (eg. NCCL timeouts).
try:
async with asyncio_timeout(ENGINE_ITERATION_TIMEOUT_S):
has_requests_in_progress = await self.engine_step()
done, _ = await asyncio.wait(
requests_in_progress,
return_when=asyncio.FIRST_COMPLETED)
for _ in range(pipeline_parallel_size):
await asyncio.sleep(0)
for task in done:
result = task.result()
virtual_engine = requests_in_progress.index(task)
if self.engine_use_ray:
has_unfinished_requests = (
await (self.engine.
has_unfinished_requests_for_virtual_engine.
remote( # type: ignore
virtual_engine)))
else:
has_unfinished_requests = (
self.engine.
has_unfinished_requests_for_virtual_engine(
virtual_engine))
if result or has_unfinished_requests:
requests_in_progress[virtual_engine] = (
asyncio.create_task(
self.engine_step(virtual_engine)))
has_requests_in_progress[virtual_engine] = True
else:
has_requests_in_progress[virtual_engine] = False
except asyncio.TimeoutError as exc:
logger.error(
"Engine iteration timed out. This should never happen!")