[Bugfix] Fix disagg hang caused by the prefill and decode communication issues (#12723)
Signed-off-by: Lu Fang <lufang@fb.com>
This commit is contained in:
@@ -10,7 +10,6 @@
|
|||||||
stop the prefill instance when the decode instance is slow.
|
stop the prefill instance when the decode instance is slow.
|
||||||
"""
|
"""
|
||||||
import threading
|
import threading
|
||||||
import time
|
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from typing import Deque, List, Optional, Union
|
from typing import Deque, List, Optional, Union
|
||||||
|
|
||||||
@@ -43,7 +42,7 @@ class SimpleBuffer(KVLookupBufferBase):
|
|||||||
|
|
||||||
self.buffer_size = 0
|
self.buffer_size = 0
|
||||||
self.buffer_size_threshold = buffer_size_thresh
|
self.buffer_size_threshold = buffer_size_thresh
|
||||||
self.buffer_lock = threading.Lock()
|
self.buffer_cv = threading.Condition()
|
||||||
self.signal_pipe = signal_pipe
|
self.signal_pipe = signal_pipe
|
||||||
self.data_pipe = data_pipe
|
self.data_pipe = data_pipe
|
||||||
self.request_handling_thread: Optional[threading.Thread] = None
|
self.request_handling_thread: Optional[threading.Thread] = None
|
||||||
@@ -116,11 +115,19 @@ class SimpleBuffer(KVLookupBufferBase):
|
|||||||
hidden = hidden.clone()
|
hidden = hidden.clone()
|
||||||
|
|
||||||
buffer_item = [input_tokens, roi, key, value, hidden]
|
buffer_item = [input_tokens, roi, key, value, hidden]
|
||||||
|
data_size = sum([self._get_element_size(data) for data in buffer_item])
|
||||||
|
|
||||||
with self.buffer_lock:
|
with self.buffer_cv:
|
||||||
for data in buffer_item:
|
if self.buffer_size + data_size > self.buffer_size_threshold:
|
||||||
self.buffer_size += self._get_element_size(data)
|
# log outside the while loop to avoid this message being logged
|
||||||
|
# repeatedly.
|
||||||
|
logger.debug("KV transfer buffer is full. Handling...")
|
||||||
|
while self.buffer_size + data_size > self.buffer_size_threshold:
|
||||||
|
self.buffer_cv.wait()
|
||||||
|
|
||||||
|
self.buffer_size += data_size
|
||||||
self.buffer.append(buffer_item)
|
self.buffer.append(buffer_item)
|
||||||
|
self.buffer_cv.notify()
|
||||||
|
|
||||||
def _is_end_signal(self, signal):
|
def _is_end_signal(self, signal):
|
||||||
return signal is None
|
return signal is None
|
||||||
@@ -143,35 +150,31 @@ class SimpleBuffer(KVLookupBufferBase):
|
|||||||
roi = (roi > 0.5)
|
roi = (roi > 0.5)
|
||||||
tokens_roi_recver = [input_tokens, roi]
|
tokens_roi_recver = [input_tokens, roi]
|
||||||
|
|
||||||
matched_length = 0
|
def is_buffer_available(
|
||||||
|
tokens_roi_recver: List[torch.Tensor], ) -> bool:
|
||||||
# perform input tokens and roi matching
|
# perform input tokens and roi matching
|
||||||
# FIXME: this matching is O(n), ideally it should be O(1)
|
# FIXME: this matching is O(n), ideally it should be O(1)
|
||||||
# but this buffer size won't (and shouldn't) be too large so
|
# but this buffer size won't (and shouldn't) be too large so
|
||||||
# the fix is not urgent.
|
# the fix is not urgent.
|
||||||
with self.buffer_lock:
|
|
||||||
|
|
||||||
for _ in range(len(self.buffer)):
|
for _ in range(len(self.buffer)):
|
||||||
|
if self._matches(self.buffer[0],
|
||||||
temp_length = self._matches(self.buffer[0],
|
tokens_roi_recver) > 0:
|
||||||
tokens_roi_recver)
|
return True
|
||||||
if temp_length > 0:
|
|
||||||
matched_length = temp_length
|
|
||||||
break
|
|
||||||
# rotate the element we just accessed to the end
|
# rotate the element we just accessed to the end
|
||||||
self.buffer.rotate(-1)
|
self.buffer.rotate(-1)
|
||||||
|
return False
|
||||||
|
|
||||||
if matched_length > 0:
|
with self.buffer_cv:
|
||||||
# need to clone the tensor
|
while not is_buffer_available(tokens_roi_recver):
|
||||||
# in case the tensor is freed before sending finishes
|
logger.debug(
|
||||||
matched_item = self.buffer.popleft()
|
"KV transfer buffer is not available. Waiting...")
|
||||||
for tensor in matched_item:
|
self.buffer_cv.wait()
|
||||||
self._send_tensor_and_dec_size(tensor)
|
# need to clone the tensor
|
||||||
|
# in case the tensor is freed before sending finishes
|
||||||
else:
|
matched_item = self.buffer.popleft()
|
||||||
# no match, just send None
|
for tensor in matched_item:
|
||||||
for _ in range(5):
|
self._send_tensor_and_dec_size(tensor)
|
||||||
self.data_pipe.send_tensor(None)
|
self.buffer_cv.notify()
|
||||||
|
|
||||||
except RuntimeError as e:
|
except RuntimeError as e:
|
||||||
if 'Connection closed by peer' not in str(e):
|
if 'Connection closed by peer' not in str(e):
|
||||||
@@ -208,20 +211,10 @@ class SimpleBuffer(KVLookupBufferBase):
|
|||||||
|
|
||||||
return [input_tokens, roi, key, value, hidden]
|
return [input_tokens, roi, key, value, hidden]
|
||||||
|
|
||||||
def full_handler(self):
|
|
||||||
time.sleep(0.001)
|
|
||||||
|
|
||||||
def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
|
def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
|
||||||
key: torch.Tensor, value: torch.Tensor,
|
key: torch.Tensor, value: torch.Tensor,
|
||||||
hidden: torch.Tensor) -> None:
|
hidden: torch.Tensor) -> None:
|
||||||
|
|
||||||
if self.buffer_size > self.buffer_size_threshold:
|
|
||||||
# log outside the while loop to avoid this message being logged
|
|
||||||
# repeatedly.
|
|
||||||
logger.debug("KV transfer buffer is full. Handling...")
|
|
||||||
while self.buffer_size > self.buffer_size_threshold:
|
|
||||||
self.full_handler()
|
|
||||||
|
|
||||||
self._add_to_buffer(input_tokens, roi, key, value, hidden)
|
self._add_to_buffer(input_tokens, roi, key, value, hidden)
|
||||||
|
|
||||||
# when calling the insert, the current process is a sender
|
# when calling the insert, the current process is a sender
|
||||||
|
|||||||
Reference in New Issue
Block a user