Files
vllm/vllm/v1/serial_utils.py
2025-04-11 17:54:06 -07:00

146 lines
5.3 KiB
Python

# SPDX-License-Identifier: Apache-2.0
import pickle
from collections.abc import Sequence
from inspect import isclass
from types import FunctionType
from typing import Any, Optional, Union
import cloudpickle
import numpy as np
import torch
import zmq
from msgspec import msgpack
CUSTOM_TYPE_PICKLE = 1
CUSTOM_TYPE_CLOUDPICKLE = 2
CUSTOM_TYPE_RAW_VIEW = 3
# TODO calibrate this size
MIN_NOCOPY_BUF_SIZE = 512
bytestr = Union[bytes, bytearray, memoryview, zmq.Frame]
class MsgpackEncoder:
    """Encoder with custom torch tensor and numpy array serialization.

    Note that unlike vanilla `msgspec` Encoders, this interface is generally
    not thread-safe when encoding tensors / numpy arrays.
    """

    def __init__(self):
        self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
        # This is used as a local stash of buffers that we can then access from
        # our custom `msgspec` hook, `enc_hook`. We don't have a way to
        # pass custom data to the hook otherwise.
        self.aux_buffers: Optional[list[bytestr]] = None

    def encode(self, obj: Any) -> Sequence[bytestr]:
        """Encode `obj`, returning the top-level msgpack buffer followed by
        zero-copy views of any large tensor / ndarray backing buffers."""
        try:
            self.aux_buffers = bufs = [b'']
            bufs[0] = self.encoder.encode(obj)
            # This `bufs` list allows us to collect direct pointers to backing
            # buffers of tensors and np arrays, and return them along with the
            # top-level encoded buffer instead of copying their data into the
            # new buffer.
            return bufs
        finally:
            self.aux_buffers = None

    def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]:
        """Like `encode`, but writes the top-level message into caller-owned
        `buf`. Returns `buf` plus any stashed zero-copy backing buffers."""
        try:
            self.aux_buffers = bufs = [buf]
            self.encoder.encode_into(obj, buf)
            return bufs
        finally:
            self.aux_buffers = None

    def enc_hook(self, obj: Any) -> Any:
        """msgspec fallback hook for types it can't natively serialize."""
        if isinstance(obj, torch.Tensor):
            # Detach from any autograd graph and move to host memory first:
            # `.numpy()` raises a RuntimeError for tensors that require grad
            # or live on an accelerator device. Both calls are no-ops for
            # plain CPU tensors, so the zero-copy fast path is preserved.
            return self._encode_ndarray(obj.detach().cpu().numpy())

        # Fall back to pickle for object or void kind ndarrays.
        if isinstance(obj, np.ndarray) and obj.dtype.kind not in ('O', 'V'):
            return self._encode_ndarray(obj)

        if isinstance(obj, FunctionType):
            # `pickle` is generally faster than cloudpickle, but can have
            # problems serializing methods.
            return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj))

        return msgpack.Ext(CUSTOM_TYPE_PICKLE,
                           pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL))

    def _encode_ndarray(
        self, obj: np.ndarray
    ) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
        """Serialize an ndarray as (dtype str, shape, inline-ext-or-index)."""
        assert self.aux_buffers is not None
        # Zero-copy view of the array's memory when contiguous; otherwise a
        # contiguous copy must be materialized.
        arr_data = obj.data if obj.data.c_contiguous else obj.tobytes()
        if not obj.shape or obj.nbytes < MIN_NOCOPY_BUF_SIZE:
            # Encode small arrays and scalars inline. Using this extension type
            # ensures we can avoid copying when decoding.
            data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
        else:
            # Otherwise encode index of backing buffer to avoid copy.
            data = len(self.aux_buffers)
            self.aux_buffers.append(arr_data)

        # We serialize the ndarray as a tuple of native types.
        # The data is either inlined if small, or an index into a list of
        # backing buffers that we've stashed in `aux_buffers`.
        return obj.dtype.str, obj.shape, data
class MsgpackDecoder:
    """Decoder with custom torch tensor and numpy array serialization.

    Note that unlike vanilla `msgspec` Decoders, this interface is generally
    not thread-safe when encoding tensors / numpy arrays.
    """

    def __init__(self, t: Optional[Any] = None):
        # When a target type is provided, msgspec decodes directly into it.
        args = (t, ) if t is not None else ()
        self.decoder = msgpack.Decoder(*args,
                                       dec_hook=self.dec_hook,
                                       ext_hook=self.ext_hook)
        self.aux_buffers: Sequence[bytestr] = ()

    def decode(self, bufs: Union[bytestr, Sequence[bytestr]]) -> Any:
        """Decode a message; `bufs` is one buffer, or a sequence whose first
        element is the top-level message and the rest backing buffers."""
        if isinstance(bufs, (bytes, bytearray, memoryview, zmq.Frame)):
            # TODO - This check can become `isinstance(bufs, bytestr)`
            # as of Python 3.10.
            return self.decoder.decode(bufs)

        # Stash the out-of-band buffers so ndarray decoding can index them.
        self.aux_buffers = bufs
        try:
            return self.decoder.decode(bufs[0])
        finally:
            self.aux_buffers = ()

    def dec_hook(self, t: type, obj: Any) -> Any:
        """Convert native-typed `obj` back into the target type `t`."""
        if not isclass(t):
            return obj
        if issubclass(t, torch.Tensor):
            return torch.from_numpy(self._decode_ndarray(obj))
        if issubclass(t, np.ndarray):
            return self._decode_ndarray(obj)
        return obj

    def _decode_ndarray(self, arr: Any) -> np.ndarray:
        """Rebuild an ndarray from its (dtype str, shape, data) tuple."""
        dtype_str, shape, data = arr
        # An int is an index into the out-of-band buffer list; anything else
        # is the inlined raw bytes themselves.
        if isinstance(data, int):
            backing = self.aux_buffers[data]
        else:
            backing = data
        return np.ndarray(buffer=backing, dtype=np.dtype(dtype_str),
                          shape=shape)

    def ext_hook(self, code: int, data: memoryview) -> Any:
        """Decode the custom msgpack extension types used by the encoder."""
        if code == CUSTOM_TYPE_RAW_VIEW:
            # Zero-copy view of inlined ndarray bytes.
            return data
        if code == CUSTOM_TYPE_PICKLE:
            return pickle.loads(data)
        if code == CUSTOM_TYPE_CLOUDPICKLE:
            return cloudpickle.loads(data)
        raise NotImplementedError(
            f"Extension type code {code} is not supported")