[V1] Allow turning off pickle fallback in vllm.v1.serial_utils (#17427)
Signed-off-by: Russell Bryant <rbryant@redhat.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
@@ -5,6 +5,7 @@ from typing import Optional
|
||||
|
||||
import msgspec
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.multimodal.inputs import (MultiModalBatchedField,
|
||||
@@ -196,3 +197,100 @@ def assert_equal(obj1: MyType, obj2: MyType):
|
||||
assert torch.equal(obj1.large_non_contig_tensor,
|
||||
obj2.large_non_contig_tensor)
|
||||
assert torch.equal(obj1.empty_tensor, obj2.empty_tensor)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("allow_pickle", [True, False])
|
||||
def test_dict_serialization(allow_pickle: bool):
|
||||
"""Test encoding and decoding of a generic Python object using pickle."""
|
||||
encoder = MsgpackEncoder(allow_pickle=allow_pickle)
|
||||
decoder = MsgpackDecoder(allow_pickle=allow_pickle)
|
||||
|
||||
# Create a sample Python object
|
||||
obj = {"key": "value", "number": 42}
|
||||
|
||||
# Encode the object
|
||||
encoded = encoder.encode(obj)
|
||||
|
||||
# Decode the object
|
||||
decoded = decoder.decode(encoded)
|
||||
|
||||
# Verify the decoded object matches the original
|
||||
assert obj == decoded, "Decoded object does not match the original object."
|
||||
|
||||
|
||||
@pytest.mark.parametrize("allow_pickle", [True, False])
|
||||
def test_tensor_serialization(allow_pickle: bool):
|
||||
"""Test encoding and decoding of a torch.Tensor."""
|
||||
encoder = MsgpackEncoder(allow_pickle=allow_pickle)
|
||||
decoder = MsgpackDecoder(torch.Tensor, allow_pickle=allow_pickle)
|
||||
|
||||
# Create a sample tensor
|
||||
tensor = torch.rand(10, 10)
|
||||
|
||||
# Encode the tensor
|
||||
encoded = encoder.encode(tensor)
|
||||
|
||||
# Decode the tensor
|
||||
decoded = decoder.decode(encoded)
|
||||
|
||||
# Verify the decoded tensor matches the original
|
||||
assert torch.allclose(
|
||||
tensor, decoded), "Decoded tensor does not match the original tensor."
|
||||
|
||||
|
||||
@pytest.mark.parametrize("allow_pickle", [True, False])
|
||||
def test_numpy_array_serialization(allow_pickle: bool):
|
||||
"""Test encoding and decoding of a numpy array."""
|
||||
encoder = MsgpackEncoder(allow_pickle=allow_pickle)
|
||||
decoder = MsgpackDecoder(np.ndarray, allow_pickle=allow_pickle)
|
||||
|
||||
# Create a sample numpy array
|
||||
array = np.random.rand(10, 10)
|
||||
|
||||
# Encode the numpy array
|
||||
encoded = encoder.encode(array)
|
||||
|
||||
# Decode the numpy array
|
||||
decoded = decoder.decode(encoded)
|
||||
|
||||
# Verify the decoded array matches the original
|
||||
assert np.allclose(
|
||||
array,
|
||||
decoded), "Decoded numpy array does not match the original array."
|
||||
|
||||
|
||||
class CustomClass:
|
||||
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, CustomClass) and self.value == other.value
|
||||
|
||||
|
||||
def test_custom_class_serialization_allowed_with_pickle():
|
||||
"""Test that serializing a custom class succeeds when allow_pickle=True."""
|
||||
encoder = MsgpackEncoder(allow_pickle=True)
|
||||
decoder = MsgpackDecoder(CustomClass, allow_pickle=True)
|
||||
|
||||
obj = CustomClass("test_value")
|
||||
|
||||
# Encode the custom class
|
||||
encoded = encoder.encode(obj)
|
||||
|
||||
# Decode the custom class
|
||||
decoded = decoder.decode(encoded)
|
||||
|
||||
# Verify the decoded object matches the original
|
||||
assert obj == decoded, "Decoded object does not match the original object."
|
||||
|
||||
|
||||
def test_custom_class_serialization_disallowed_without_pickle():
|
||||
"""Test that serializing a custom class fails when allow_pickle=False."""
|
||||
encoder = MsgpackEncoder(allow_pickle=False)
|
||||
|
||||
obj = CustomClass("test_value")
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
# Attempt to encode the custom class
|
||||
encoder.encode(obj)
|
||||
|
||||
Reference in New Issue
Block a user