Convert formatting to use ruff instead of yapf + isort (#26247)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@@ -9,18 +9,21 @@ import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.multimodal.inputs import (MultiModalBatchedField,
|
||||
MultiModalFieldElem, MultiModalFlatField,
|
||||
MultiModalKwargsItem,
|
||||
MultiModalKwargsItems,
|
||||
MultiModalSharedField, NestedTensors)
|
||||
from vllm.multimodal.inputs import (
|
||||
MultiModalBatchedField,
|
||||
MultiModalFieldElem,
|
||||
MultiModalFlatField,
|
||||
MultiModalKwargsItem,
|
||||
MultiModalKwargsItems,
|
||||
MultiModalSharedField,
|
||||
NestedTensors,
|
||||
)
|
||||
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
|
||||
|
||||
pytestmark = pytest.mark.cpu_test
|
||||
|
||||
|
||||
class UnrecognizedType(UserDict):
|
||||
|
||||
def __init__(self, an_int: int):
|
||||
super().__init__()
|
||||
self.an_int = an_int
|
||||
@@ -47,10 +50,7 @@ def test_encode_decode(monkeypatch: pytest.MonkeyPatch):
|
||||
m.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
|
||||
|
||||
obj = MyType(
|
||||
tensor1=torch.randint(low=0,
|
||||
high=100,
|
||||
size=(1024, ),
|
||||
dtype=torch.int32),
|
||||
tensor1=torch.randint(low=0, high=100, size=(1024,), dtype=torch.int32),
|
||||
a_string="hello",
|
||||
list_of_tensors=[
|
||||
torch.rand((1, 10), dtype=torch.float32),
|
||||
@@ -58,8 +58,9 @@ def test_encode_decode(monkeypatch: pytest.MonkeyPatch):
|
||||
torch.tensor(1984), # test scalar too
|
||||
# Make sure to test bf16 which numpy doesn't support.
|
||||
torch.rand((3, 5, 1000), dtype=torch.bfloat16),
|
||||
torch.tensor([float("-inf"), float("inf")] * 1024,
|
||||
dtype=torch.bfloat16),
|
||||
torch.tensor(
|
||||
[float("-inf"), float("inf")] * 1024, dtype=torch.bfloat16
|
||||
),
|
||||
],
|
||||
numpy_array=np.arange(512),
|
||||
unrecognized=UnrecognizedType(33),
|
||||
@@ -103,22 +104,24 @@ class MyRequest(msgspec.Struct):
|
||||
|
||||
|
||||
def test_multimodal_kwargs():
|
||||
e1 = MultiModalFieldElem("audio", "a0",
|
||||
torch.zeros(1000, dtype=torch.bfloat16),
|
||||
MultiModalBatchedField())
|
||||
e1 = MultiModalFieldElem(
|
||||
"audio", "a0", torch.zeros(1000, dtype=torch.bfloat16), MultiModalBatchedField()
|
||||
)
|
||||
e2 = MultiModalFieldElem(
|
||||
"video",
|
||||
"v0",
|
||||
[torch.zeros(1000, dtype=torch.int8) for _ in range(4)],
|
||||
MultiModalFlatField(
|
||||
[[slice(1, 2, 3), slice(4, 5, 6)], [slice(None, 2)]], 0),
|
||||
MultiModalFlatField([[slice(1, 2, 3), slice(4, 5, 6)], [slice(None, 2)]], 0),
|
||||
)
|
||||
e3 = MultiModalFieldElem(
|
||||
"image", "i0", torch.zeros(1000, dtype=torch.int32), MultiModalSharedField(4)
|
||||
)
|
||||
e3 = MultiModalFieldElem("image", "i0", torch.zeros(1000,
|
||||
dtype=torch.int32),
|
||||
MultiModalSharedField(4))
|
||||
e4 = MultiModalFieldElem(
|
||||
"image", "i1", torch.zeros(1000, dtype=torch.int32),
|
||||
MultiModalFlatField([slice(1, 2, 3), slice(4, 5, 6)], 2))
|
||||
"image",
|
||||
"i1",
|
||||
torch.zeros(1000, dtype=torch.int32),
|
||||
MultiModalFlatField([slice(1, 2, 3), slice(4, 5, 6)], 2),
|
||||
)
|
||||
audio = MultiModalKwargsItem.from_elems([e1])
|
||||
video = MultiModalKwargsItem.from_elems([e2])
|
||||
image = MultiModalKwargsItem.from_elems([e3, e4])
|
||||
@@ -164,16 +167,14 @@ def assert_equal(obj1: MyType, obj2: MyType):
|
||||
assert torch.equal(obj1.tensor1, obj2.tensor1)
|
||||
assert obj1.a_string == obj2.a_string
|
||||
assert all(
|
||||
torch.equal(a, b)
|
||||
for a, b in zip(obj1.list_of_tensors, obj2.list_of_tensors))
|
||||
torch.equal(a, b) for a, b in zip(obj1.list_of_tensors, obj2.list_of_tensors)
|
||||
)
|
||||
assert np.array_equal(obj1.numpy_array, obj2.numpy_array)
|
||||
assert obj1.unrecognized.an_int == obj2.unrecognized.an_int
|
||||
assert torch.equal(obj1.small_f_contig_tensor, obj2.small_f_contig_tensor)
|
||||
assert torch.equal(obj1.large_f_contig_tensor, obj2.large_f_contig_tensor)
|
||||
assert torch.equal(obj1.small_non_contig_tensor,
|
||||
obj2.small_non_contig_tensor)
|
||||
assert torch.equal(obj1.large_non_contig_tensor,
|
||||
obj2.large_non_contig_tensor)
|
||||
assert torch.equal(obj1.small_non_contig_tensor, obj2.small_non_contig_tensor)
|
||||
assert torch.equal(obj1.large_non_contig_tensor, obj2.large_non_contig_tensor)
|
||||
assert torch.equal(obj1.empty_tensor, obj2.empty_tensor)
|
||||
|
||||
|
||||
@@ -210,8 +211,9 @@ def test_tensor_serialization():
|
||||
decoded = decoder.decode(encoded)
|
||||
|
||||
# Verify the decoded tensor matches the original
|
||||
assert torch.allclose(
|
||||
tensor, decoded), "Decoded tensor does not match the original tensor."
|
||||
assert torch.allclose(tensor, decoded), (
|
||||
"Decoded tensor does not match the original tensor."
|
||||
)
|
||||
|
||||
|
||||
def test_numpy_array_serialization():
|
||||
@@ -229,13 +231,12 @@ def test_numpy_array_serialization():
|
||||
decoded = decoder.decode(encoded)
|
||||
|
||||
# Verify the decoded array matches the original
|
||||
assert np.allclose(
|
||||
array,
|
||||
decoded), "Decoded numpy array does not match the original array."
|
||||
assert np.allclose(array, decoded), (
|
||||
"Decoded numpy array does not match the original array."
|
||||
)
|
||||
|
||||
|
||||
class CustomClass:
|
||||
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
@@ -244,7 +245,8 @@ class CustomClass:
|
||||
|
||||
|
||||
def test_custom_class_serialization_allowed_with_pickle(
|
||||
monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
):
|
||||
"""Test that serializing a custom class succeeds when allow_pickle=True."""
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
@@ -261,8 +263,7 @@ def test_custom_class_serialization_allowed_with_pickle(
|
||||
decoded = decoder.decode(encoded)
|
||||
|
||||
# Verify the decoded object matches the original
|
||||
assert obj == decoded, (
|
||||
"Decoded object does not match the original object.")
|
||||
assert obj == decoded, "Decoded object does not match the original object."
|
||||
|
||||
|
||||
def test_custom_class_serialization_disallowed_without_pickle():
|
||||
|
||||
Reference in New Issue
Block a user