[BugFix] Handle non-contiguous tensors properly when serializing (#16492)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill
2025-04-11 17:54:06 -07:00
committed by GitHub
parent 57504a4bcf
commit 41cc883c29
2 changed files with 30 additions and 11 deletions

View File

@@ -14,9 +14,10 @@ from msgspec import msgpack
CUSTOM_TYPE_PICKLE = 1
CUSTOM_TYPE_CLOUDPICKLE = 2
CUSTOM_TYPE_RAW_VIEW = 3
# TODO calibrate this size
INLINE_BUF_SIZE_THRESHOLD = 256
MIN_NOCOPY_BUF_SIZE = 512
bytestr = Union[bytes, bytearray, memoryview, zmq.Frame]
@@ -76,14 +77,16 @@ class MsgpackEncoder:
self, obj: np.ndarray
) -> tuple[str, tuple[int, ...], Union[int, memoryview]]:
assert self.aux_buffers is not None
if not obj.shape or obj.nbytes < INLINE_BUF_SIZE_THRESHOLD:
# Encode small arrays and scalars inline.
data = obj.data
arr_data = obj.data if obj.data.c_contiguous else obj.tobytes()
if not obj.shape or obj.nbytes < MIN_NOCOPY_BUF_SIZE:
# Encode small arrays and scalars inline. Using this extension type
# ensures we can avoid copying when decoding.
data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
else:
# Otherwise encode index of backing buffer.
obj = np.ascontiguousarray(obj)
# Otherwise encode index of backing buffer to avoid copy.
data = len(self.aux_buffers)
self.aux_buffers.append(obj.data)
self.aux_buffers.append(arr_data)
# We serialize the ndarray as a tuple of native types.
# The data is either inlined if small, or an index into a list of
# backing buffers that we've stashed in `aux_buffers`.
@@ -131,6 +134,8 @@ class MsgpackDecoder:
return np.ndarray(buffer=buffer, dtype=np.dtype(dtype), shape=shape)
def ext_hook(self, code: int, data: memoryview) -> Any:
if code == CUSTOM_TYPE_RAW_VIEW:
return data
if code == CUSTOM_TYPE_PICKLE:
return pickle.loads(data)
if code == CUSTOM_TYPE_CLOUDPICKLE: