[BugFix] Avoid race conditions in zero-copy tensor transmission (#17203)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -5,6 +5,7 @@ import signal
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from collections import deque
|
||||
from concurrent.futures import Future
|
||||
from inspect import isclass, signature
|
||||
from logging import DEBUG
|
||||
@@ -527,8 +528,12 @@ class EngineCoreProc(EngineCore):
|
||||
|
||||
# Msgpack serialization encoding.
|
||||
encoder = MsgpackEncoder()
|
||||
# Reuse send buffer.
|
||||
buffer = bytearray()
|
||||
# Send buffers to reuse.
|
||||
reuse_buffers: list[bytearray] = []
|
||||
# Keep references to outputs and buffers until zmq is finished
|
||||
# with them (outputs may contain tensors/np arrays whose
|
||||
# backing buffers were extracted for zero-copy send).
|
||||
pending = deque[tuple[zmq.MessageTracker, Any, bytearray]]()
|
||||
|
||||
# We must set linger to ensure the ENGINE_CORE_DEAD
|
||||
# message is sent prior to closing the socket.
|
||||
@@ -541,8 +546,22 @@ class EngineCoreProc(EngineCore):
|
||||
break
|
||||
assert not isinstance(outputs, bytes)
|
||||
outputs.engine_index = engine_index
|
||||
|
||||
# Reclaim buffers that zmq is finished with.
|
||||
while pending and pending[-1][0].done:
|
||||
reuse_buffers.append(pending.pop()[2])
|
||||
|
||||
buffer = reuse_buffers.pop() if reuse_buffers else bytearray()
|
||||
buffers = encoder.encode_into(outputs, buffer)
|
||||
socket.send_multipart(buffers, copy=False)
|
||||
tracker = socket.send_multipart(buffers,
|
||||
copy=False,
|
||||
track=True)
|
||||
if not tracker.done:
|
||||
ref = outputs if len(buffers) > 1 else None
|
||||
pending.appendleft((tracker, ref, buffer))
|
||||
elif len(reuse_buffers) < 2:
|
||||
# Keep at most 2 buffers to reuse.
|
||||
reuse_buffers.append(buffer)
|
||||
|
||||
|
||||
class DPEngineCoreProc(EngineCoreProc):
|
||||
|
||||
Reference in New Issue
Block a user