[ci][distributed] add pipeline parallel correctness test (#6410)

This commit is contained in:
youkaichao
2024-07-16 15:44:22 -07:00
committed by GitHub
parent 978aed5300
commit 09c2eb85dd
3 changed files with 119 additions and 118 deletions

View File

@@ -1,5 +1,7 @@
import asyncio
import os
import signal
import weakref
from functools import partial
from typing import Any, List, Optional
@@ -78,6 +80,19 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
result_handler.start()
self.worker_monitor.start()
# Set up signal handlers to shut down the executor cleanly,
# because garbage collection cannot always be relied on to do it.
# Use weakref to avoid holding a reference to self
ref = weakref.ref(self)
def shutdown(signum, frame):
    """Signal handler that shuts the executor down if it still exists.

    ``ref`` is the weak reference captured from the enclosing scope;
    calling it yields the executor, or ``None`` once the executor has
    been collected. ``signum`` and ``frame`` are the standard signal
    handler arguments and are not used here.
    """
    live_executor = ref()
    if not live_executor:
        # The referent is gone (or falsy): nothing to shut down.
        return
    live_executor.shutdown()
signal.signal(signal.SIGINT, shutdown)
signal.signal(signal.SIGTERM, shutdown)
self.driver_worker = self._create_worker(
distributed_init_method=distributed_init_method)
self._run_workers("init_device")