[ Frontend ] Multiprocessing for OpenAI Server with zeromq (#6883)

Signed-off-by: Joe Runde <Joseph.Runde@ibm.com>
Co-authored-by: Joe Runde <Joseph.Runde@ibm.com>
Co-authored-by: Joe Runde <joe@joerun.de>
Co-authored-by: Nick Hill <nickhill@us.ibm.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>
This commit is contained in:
Robert Shaw
2024-08-02 21:27:28 -04:00
committed by GitHub
parent 708989341e
commit ed812a73fa
20 changed files with 1567 additions and 101 deletions

View File

@@ -290,6 +290,10 @@ def make_async(func: Callable[P, T]) -> Callable[P, Awaitable[T]]:
return _async_wrapper
class ProducerFinished:
    """Sentinel queued by a producer task to tell the consumer it is done.

    Instances carry no state; the consumer only checks ``isinstance`` to
    distinguish this end-of-stream marker from real ``(index, item)`` tuples
    and from exceptions forwarded through the queue.
    """
def merge_async_iterators(
*iterators: AsyncIterator[T]) -> AsyncIterator[Tuple[int, T]]:
"""Merge multiple asynchronous iterators into a single iterator.
@@ -298,9 +302,10 @@ def merge_async_iterators(
When it yields, it yields a tuple (i, item) where i is the index of the
iterator that yields the item.
"""
queue: asyncio.Queue[Union[Tuple[int, T], Exception]] = asyncio.Queue()
queue: asyncio.Queue[Union[Tuple[int, T], ProducerFinished,
Exception]] = asyncio.Queue()
finished = [False] * len(iterators)
producers = len(iterators)
async def producer(i: int, iterator: AsyncIterator[T]):
try:
@@ -308,7 +313,8 @@ def merge_async_iterators(
await queue.put((i, item))
except Exception as e:
await queue.put(e)
finished[i] = True
# Signal to the consumer that we've finished
await queue.put(ProducerFinished())
_tasks = [
asyncio.create_task(producer(i, iterator))
@@ -316,9 +322,17 @@ def merge_async_iterators(
]
async def consumer():
remaining = producers
try:
while not all(finished) or not queue.empty():
while remaining or not queue.empty():
# we think there is a race condition here
item = await queue.get()
if isinstance(item, ProducerFinished):
# Signal that a producer finished- not a real item
remaining -= 1
continue
if isinstance(item, Exception):
raise item
yield item
@@ -374,8 +388,10 @@ def get_distributed_init_method(ip: str, port: int) -> str:
return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}"
def get_open_port() -> int:
port = envs.VLLM_PORT
def get_open_port(port: Optional[int] = None) -> int:
if port is None:
# Default behavior here is to return a port for multi-gpu communication
port = envs.VLLM_PORT
if port is not None:
while True:
try: