[MISC][Bugfix] Use less CPU when message queue has been empty for some time (#16226)
Signed-off-by: Povilas Kanapickas <povilas@radix.lt>
This commit is contained in:
committed by
GitHub
parent
61059bee40
commit
85e2b7bb13
@@ -28,6 +28,43 @@ VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class SpinTimer:
    """Baseline wait strategy that yields the processor on every spin.

    This timer keeps no state: recording activity has no effect and each
    call to ``spin`` simply invokes ``sched_yield``.
    """

    def record_activity(self):
        """Do nothing; this timer does not track activity."""

    def spin(self):
        """Give up the processor once by calling ``sched_yield``."""
        sched_yield()
|
||||
|
||||
|
||||
class SpinSleepTimer(SpinTimer):
    """Wait strategy that switches from yielding to sleeping when idle.

    Some setups see long stretches without any activity. Lowering system
    power consumption while vllm has nothing to do leaves more CPU thermal
    headroom for the moment a request finally arrives. This matters most
    when several GPUs are attached, because each GPU would otherwise keep
    one thread pinned at 100% CPU usage.

    The approach taken here is the simplest one: poll less often once no
    activity has been recorded for a certain period of time.
    """

    def __init__(self, busy_loop_s: float = 3.0, wait_sleep_s: float = 0.1):
        # busy_loop_s: how long after the last recorded activity spin()
        #     keeps yielding before it starts sleeping.
        # wait_sleep_s: duration passed to time.sleep() on each idle spin.
        self.busy_loop_s = busy_loop_s
        self.wait_sleep_s = wait_sleep_s
        self.last_activity = time.monotonic()

    def record_activity(self):
        """Stamp the current monotonic time as the latest activity."""
        self.last_activity = time.monotonic()

    def spin(self):
        """Wait once: sleep if idle for ``busy_loop_s``, otherwise yield."""
        now = time.monotonic()
        idle_threshold = self.last_activity + self.busy_loop_s
        if now >= idle_threshold:
            # Idle for long enough: reduce the polling frequency.
            time.sleep(self.wait_sleep_s)
            return
        # Recent activity: stay responsive and only yield the processor.
        sched_yield()
|
||||
|
||||
|
||||
class ShmRingBuffer:
|
||||
|
||||
def __init__(self,
|
||||
@@ -42,7 +79,7 @@ class ShmRingBuffer:
|
||||
of items that can be stored in the buffer are known in advance.
|
||||
In this case, we don't need to synchronize the access to
|
||||
the buffer.
|
||||
|
||||
|
||||
Buffer memory layout:
|
||||
data metadata
|
||||
| |
|
||||
@@ -238,6 +275,7 @@ class MessageQueue:
|
||||
self.local_reader_rank = -1
|
||||
# rank does not matter for remote readers
|
||||
self._is_remote_reader = False
|
||||
self._read_spin_timer = SpinTimer()
|
||||
|
||||
self.handle = Handle(
|
||||
local_reader_ranks=local_reader_ranks,
|
||||
@@ -276,6 +314,9 @@ class MessageQueue:
|
||||
self.local_socket.connect(socket_addr)
|
||||
|
||||
self.remote_socket = None
|
||||
|
||||
self._read_spin_timer = SpinSleepTimer(
|
||||
) if envs.VLLM_SLEEP_WHEN_IDLE else SpinTimer()
|
||||
else:
|
||||
self.buffer = None # type: ignore
|
||||
self.current_idx = -1
|
||||
@@ -407,7 +448,7 @@ class MessageQueue:
|
||||
# we need to wait until it is written
|
||||
|
||||
# Release the processor to other threads
|
||||
sched_yield()
|
||||
self._read_spin_timer.spin()
|
||||
|
||||
# if we wait for a long time, log a message
|
||||
if (time.monotonic() - start_time
|
||||
@@ -438,6 +479,8 @@ class MessageQueue:
|
||||
metadata_buffer[self.local_reader_rank + 1] = 1
|
||||
self.current_idx = (self.current_idx +
|
||||
1) % self.buffer.max_chunks
|
||||
|
||||
self._read_spin_timer.record_activity()
|
||||
break
|
||||
|
||||
def enqueue(self, obj, timeout: Optional[float] = None):
|
||||
|
||||
Reference in New Issue
Block a user