[ Frontend ] Multiprocessing for OpenAI Server with zeromq (#6883)
Signed-off-by: Joe Runde <Joseph.Runde@ibm.com> Co-authored-by: Joe Runde <Joseph.Runde@ibm.com> Co-authored-by: Joe Runde <joe@joerun.de> Co-authored-by: Nick Hill <nickhill@us.ibm.com> Co-authored-by: Simon Mo <simon.mo@hey.com>
This commit is contained in:
@@ -290,6 +290,10 @@ def make_async(func: Callable[P, T]) -> Callable[P, Awaitable[T]]:
|
||||
return _async_wrapper
|
||||
|
||||
|
||||
class ProducerFinished:
    """Sentinel type placed on the merge queue to signal that one
    producer task has completed; it is not a real yielded item."""
|
||||
|
||||
|
||||
def merge_async_iterators(
|
||||
*iterators: AsyncIterator[T]) -> AsyncIterator[Tuple[int, T]]:
|
||||
"""Merge multiple asynchronous iterators into a single iterator.
|
||||
@@ -298,9 +302,10 @@ def merge_async_iterators(
|
||||
When it yields, it yields a tuple (i, item) where i is the index of the
|
||||
iterator that yields the item.
|
||||
"""
|
||||
queue: asyncio.Queue[Union[Tuple[int, T], Exception]] = asyncio.Queue()
|
||||
queue: asyncio.Queue[Union[Tuple[int, T], ProducerFinished,
|
||||
Exception]] = asyncio.Queue()
|
||||
|
||||
finished = [False] * len(iterators)
|
||||
producers = len(iterators)
|
||||
|
||||
async def producer(i: int, iterator: AsyncIterator[T]):
|
||||
try:
|
||||
@@ -308,7 +313,8 @@ def merge_async_iterators(
|
||||
await queue.put((i, item))
|
||||
except Exception as e:
|
||||
await queue.put(e)
|
||||
finished[i] = True
|
||||
# Signal to the consumer that we've finished
|
||||
await queue.put(ProducerFinished())
|
||||
|
||||
_tasks = [
|
||||
asyncio.create_task(producer(i, iterator))
|
||||
@@ -316,9 +322,17 @@ def merge_async_iterators(
|
||||
]
|
||||
|
||||
async def consumer():
|
||||
remaining = producers
|
||||
try:
|
||||
while not all(finished) or not queue.empty():
|
||||
while remaining or not queue.empty():
|
||||
# we think there is a race condition here
|
||||
item = await queue.get()
|
||||
|
||||
if isinstance(item, ProducerFinished):
|
||||
# Signal that a producer finished- not a real item
|
||||
remaining -= 1
|
||||
continue
|
||||
|
||||
if isinstance(item, Exception):
|
||||
raise item
|
||||
yield item
|
||||
@@ -374,8 +388,10 @@ def get_distributed_init_method(ip: str, port: int) -> str:
|
||||
return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}"
|
||||
|
||||
|
||||
def get_open_port() -> int:
|
||||
port = envs.VLLM_PORT
|
||||
def get_open_port(port: Optional[int] = None) -> int:
|
||||
if port is None:
|
||||
# Default behavior here is to return a port for multi-gpu communication
|
||||
port = envs.VLLM_PORT
|
||||
if port is not None:
|
||||
while True:
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user