[V1] Use msgpack for core request serialization (#12918)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill
2025-02-09 19:35:56 -08:00
committed by GitHub
parent aa0ca5ebb7
commit 67c4637ccf
4 changed files with 62 additions and 95 deletions

View File

@@ -1,12 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
import pickle
import queue
import signal
import threading
import time
from multiprocessing.connection import Connection
from typing import List, Tuple, Type
from typing import Any, List, Tuple, Type
import psutil
import zmq
@@ -19,13 +18,12 @@ from vllm.transformers_utils.config import (
from vllm.utils import get_exception_traceback, zmq_socket_ctx
from vllm.v1.core.kv_cache_utils import get_kv_cache_config
from vllm.v1.core.scheduler import Scheduler
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile,
EngineCoreRequest, EngineCoreRequestType,
EngineCoreRequestUnion, EngineCoreResetPrefixCache)
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType)
from vllm.v1.engine.mm_input_mapper import MMInputMapperServer
from vllm.v1.executor.abstract import Executor
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import MsgpackEncoder, PickleEncoder
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
@@ -161,7 +159,8 @@ class EngineCoreProc(EngineCore):
# and to overlap some serialization/deserialization with the
# model forward pass.
# Threads handle Socket <-> Queues and core_busy_loop uses Queue.
self.input_queue: queue.Queue[EngineCoreRequestUnion] = queue.Queue()
self.input_queue: queue.Queue[Tuple[EngineCoreRequestType,
Any]] = queue.Queue()
self.output_queue: queue.Queue[EngineCoreOutputs] = queue.Queue()
threading.Thread(target=self.process_input_socket,
args=(input_path, ),
@@ -223,7 +222,7 @@ class EngineCoreProc(EngineCore):
while True:
try:
req = self.input_queue.get(timeout=POLLING_TIMEOUT_S)
self._handle_client_request(req)
self._handle_client_request(*req)
break
except queue.Empty:
logger.debug("EngineCore busy loop waiting.")
@@ -233,10 +232,10 @@ class EngineCoreProc(EngineCore):
except BaseException:
raise
# 2) Handle any new client requests (Abort or Add).
# 2) Handle any new client requests.
while not self.input_queue.empty():
req = self.input_queue.get_nowait()
self._handle_client_request(req)
self._handle_client_request(*req)
# 3) Step the engine core.
outputs = self.step()
@@ -244,48 +243,40 @@ class EngineCoreProc(EngineCore):
# 5) Put EngineCoreOutputs into the output queue.
self.output_queue.put_nowait(outputs)
def _handle_client_request(self, request: EngineCoreRequestUnion) -> None:
"""Handle EngineCoreRequest or EngineCoreABORT from Client."""
def _handle_client_request(self, request_type: EngineCoreRequestType,
request: Any) -> None:
"""Dispatch request from client."""
if isinstance(request, EngineCoreRequest):
if request_type == EngineCoreRequestType.ADD:
self.add_request(request)
elif isinstance(request, EngineCoreProfile):
self.model_executor.profile(request.is_start)
elif isinstance(request, EngineCoreResetPrefixCache):
self.reset_prefix_cache()
else:
# TODO: make an EngineCoreAbort wrapper
assert isinstance(request, list)
elif request_type == EngineCoreRequestType.ABORT:
self.abort_requests(request)
elif request_type == EngineCoreRequestType.RESET_PREFIX_CACHE:
self.reset_prefix_cache()
elif request_type == EngineCoreRequestType.PROFILE:
self.model_executor.profile(request)
def process_input_socket(self, input_path: str):
"""Input socket IO thread."""
# Msgpack serialization decoding.
decoder_add_req = PickleEncoder()
decoder_abort_req = PickleEncoder()
add_request_decoder = MsgpackDecoder(EngineCoreRequest)
generic_decoder = MsgpackDecoder()
with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket:
while True:
# (RequestType, RequestData)
type_frame, data_frame = socket.recv_multipart(copy=False)
request_type = type_frame.buffer
request_data = data_frame.buffer
request_type = EngineCoreRequestType(bytes(type_frame.buffer))
# Deserialize the request data.
if request_type == EngineCoreRequestType.ADD.value:
request = decoder_add_req.decode(request_data)
elif request_type == EngineCoreRequestType.ABORT.value:
request = decoder_abort_req.decode(request_data)
elif request_type in (
EngineCoreRequestType.PROFILE.value,
EngineCoreRequestType.RESET_PREFIX_CACHE.value):
request = pickle.loads(request_data)
else:
raise ValueError(f"Unknown RequestType: {request_type}")
decoder = add_request_decoder if (
request_type
== EngineCoreRequestType.ADD) else generic_decoder
request = decoder.decode(data_frame.buffer)
# Push to input queue for core busy loop.
self.input_queue.put_nowait(request)
self.input_queue.put_nowait((request_type, request))
def process_output_socket(self, output_path: str):
"""Output socket IO thread."""