[ci][distributed] add pipeline parallel correctness test (#6410)
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
import asyncio
|
||||
import os
|
||||
import signal
|
||||
import weakref
|
||||
from functools import partial
|
||||
from typing import Any, List, Optional
|
||||
|
||||
@@ -78,6 +80,19 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
|
||||
result_handler.start()
|
||||
self.worker_monitor.start()
|
||||
|
||||
# Set up signal handlers to shutdown the executor cleanly
|
||||
# sometimes gc does not work well
|
||||
|
||||
# Use weakref to avoid holding a reference to self
|
||||
ref = weakref.ref(self)
|
||||
|
||||
def shutdown(signum, frame):
|
||||
if executor := ref():
|
||||
executor.shutdown()
|
||||
|
||||
signal.signal(signal.SIGINT, shutdown)
|
||||
signal.signal(signal.SIGTERM, shutdown)
|
||||
|
||||
self.driver_worker = self._create_worker(
|
||||
distributed_init_method=distributed_init_method)
|
||||
self._run_workers("init_device")
|
||||
|
||||
Reference in New Issue
Block a user