[V1] Use msgpack for core request serialization (#12918)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -1,12 +1,11 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import pickle
|
||||
import queue
|
||||
import signal
|
||||
import threading
|
||||
import time
|
||||
from multiprocessing.connection import Connection
|
||||
from typing import List, Tuple, Type
|
||||
from typing import Any, List, Tuple, Type
|
||||
|
||||
import psutil
|
||||
import zmq
|
||||
@@ -19,13 +18,12 @@ from vllm.transformers_utils.config import (
|
||||
from vllm.utils import get_exception_traceback, zmq_socket_ctx
|
||||
from vllm.v1.core.kv_cache_utils import get_kv_cache_config
|
||||
from vllm.v1.core.scheduler import Scheduler
|
||||
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile,
|
||||
EngineCoreRequest, EngineCoreRequestType,
|
||||
EngineCoreRequestUnion, EngineCoreResetPrefixCache)
|
||||
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
|
||||
EngineCoreRequestType)
|
||||
from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
from vllm.v1.serial_utils import MsgpackEncoder, PickleEncoder
|
||||
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@@ -161,7 +159,8 @@ class EngineCoreProc(EngineCore):
|
||||
# and to overlap some serialization/deserialization with the
|
||||
# model forward pass.
|
||||
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
|
||||
self.input_queue: queue.Queue[EngineCoreRequestUnion] = queue.Queue()
|
||||
self.input_queue: queue.Queue[Tuple[EngineCoreRequestType,
|
||||
Any]] = queue.Queue()
|
||||
self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue()
|
||||
threading.Thread(target=self.process_input_socket,
|
||||
args=(input_path, ),
|
||||
@@ -223,7 +222,7 @@ class EngineCoreProc(EngineCore):
|
||||
while True:
|
||||
try:
|
||||
req = self.input_queue.get(timeout=POLLING_TIMEOUT_S)
|
||||
self._handle_client_request(req)
|
||||
self._handle_client_request(*req)
|
||||
break
|
||||
except queue.Empty:
|
||||
logger.debug("EngineCore busy loop waiting.")
|
||||
@@ -233,10 +232,10 @@ class EngineCoreProc(EngineCore):
|
||||
except BaseException:
|
||||
raise
|
||||
|
||||
# 2) Handle any new client requests (Abort or Add).
|
||||
# 2) Handle any new client requests.
|
||||
while not self.input_queue.empty():
|
||||
req = self.input_queue.get_nowait()
|
||||
self._handle_client_request(req)
|
||||
self._handle_client_request(*req)
|
||||
|
||||
# 3) Step the engine core.
|
||||
outputs = self.step()
|
||||
@@ -244,48 +243,40 @@ class EngineCoreProc(EngineCore):
|
||||
# 5) Put EngineCoreOutputs into the output queue.
|
||||
self.output_queue.put_nowait(outputs)
|
||||
|
||||
def _handle_client_request(self, request: EngineCoreRequestUnion) -> None:
|
||||
"""Handle EngineCoreRequest or EngineCoreABORT from Client."""
|
||||
def _handle_client_request(self, request_type: EngineCoreRequestType,
|
||||
request: Any) -> None:
|
||||
"""Dispatch request from client."""
|
||||
|
||||
if isinstance(request, EngineCoreRequest):
|
||||
if request_type == EngineCoreRequestType.ADD:
|
||||
self.add_request(request)
|
||||
elif isinstance(request, EngineCoreProfile):
|
||||
self.model_executor.profile(request.is_start)
|
||||
elif isinstance(request, EngineCoreResetPrefixCache):
|
||||
self.reset_prefix_cache()
|
||||
else:
|
||||
# TODO: make an EngineCoreAbort wrapper
|
||||
assert isinstance(request, list)
|
||||
elif request_type == EngineCoreRequestType.ABORT:
|
||||
self.abort_requests(request)
|
||||
elif request_type == EngineCoreRequestType.RESET_PREFIX_CACHE:
|
||||
self.reset_prefix_cache()
|
||||
elif request_type == EngineCoreRequestType.PROFILE:
|
||||
self.model_executor.profile(request)
|
||||
|
||||
def process_input_socket(self, input_path: str):
|
||||
"""Input socket IO thread."""
|
||||
|
||||
# Msgpack serialization decoding.
|
||||
decoder_add_req = PickleEncoder()
|
||||
decoder_abort_req = PickleEncoder()
|
||||
add_request_decoder = MsgpackDecoder(EngineCoreRequest)
|
||||
generic_decoder = MsgpackDecoder()
|
||||
|
||||
with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket:
|
||||
while True:
|
||||
# (RequestType, RequestData)
|
||||
type_frame, data_frame = socket.recv_multipart(copy=False)
|
||||
request_type = type_frame.buffer
|
||||
request_data = data_frame.buffer
|
||||
request_type = EngineCoreRequestType(bytes(type_frame.buffer))
|
||||
|
||||
# Deserialize the request data.
|
||||
if request_type == EngineCoreRequestType.ADD.value:
|
||||
request = decoder_add_req.decode(request_data)
|
||||
elif request_type == EngineCoreRequestType.ABORT.value:
|
||||
request = decoder_abort_req.decode(request_data)
|
||||
elif request_type in (
|
||||
EngineCoreRequestType.PROFILE.value,
|
||||
EngineCoreRequestType.RESET_PREFIX_CACHE.value):
|
||||
request = pickle.loads(request_data)
|
||||
else:
|
||||
raise ValueError(f"Unknown RequestType: {request_type}")
|
||||
decoder = add_request_decoder if (
|
||||
request_type
|
||||
== EngineCoreRequestType.ADD) else generic_decoder
|
||||
request = decoder.decode(data_frame.buffer)
|
||||
|
||||
# Push to input queue for core busy loop.
|
||||
self.input_queue.put_nowait(request)
|
||||
self.input_queue.put_nowait((request_type, request))
|
||||
|
||||
def process_output_socket(self, output_path: str):
|
||||
"""Output socket IO thread."""
|
||||
|
||||
Reference in New Issue
Block a user