[V1][Performance] Implement custom serialization for MultiModalKwargs [Rebased] (#16432)
Signed-off-by: Staszek Pasko <staszek@gmail.com> Signed-off-by: Nick Hill <nhill@redhat.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com> Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@@ -1,10 +1,16 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from collections import UserDict
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import msgspec
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from vllm.multimodal.inputs import (MultiModalBatchedField,
|
||||
MultiModalFieldElem, MultiModalKwargs,
|
||||
MultiModalKwargsItem,
|
||||
MultiModalSharedField, NestedTensors)
|
||||
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
|
||||
|
||||
|
||||
@@ -50,7 +56,7 @@ def test_encode_decode():
|
||||
large_non_contig_tensor=torch.rand(1024, 512)[:, 10:20],
|
||||
)
|
||||
|
||||
encoder = MsgpackEncoder()
|
||||
encoder = MsgpackEncoder(size_threshold=256)
|
||||
decoder = MsgpackDecoder(MyType)
|
||||
|
||||
encoded = encoder.encode(obj)
|
||||
@@ -78,6 +84,97 @@ def test_encode_decode():
|
||||
assert_equal(decoded2, obj)
|
||||
|
||||
|
||||
class MyRequest(msgspec.Struct):
    """Mock request that wraps multimodal kwargs for the round-trip tests.

    The tests pack their MultiModalKwargs into this struct so that the
    decoder is given a typed target to decode into.
    """

    # Optional list of multimodal kwargs carried by the request.
    mm: Optional[list[MultiModalKwargs]]
|
||||
|
||||
|
||||
def test_multimodal_kwargs():
    """Round-trip a MultiModalKwargs holding nested tensors via msgpack.

    Checks the number of buffers produced by the encoder, the total encoded
    size (within a small tolerance), and that every decoded entry matches
    the tensors that were put in.
    """
    # A single flat tensor, a list of tensors of growing size, and a mixed
    # structure of tensors and lists of tensors with different dtypes.
    foo = torch.zeros(20000, dtype=torch.float16)
    bar = [torch.zeros(size, dtype=torch.int8) for size in range(0, 3000, 1000)]
    baz = [
        torch.rand(256, dtype=torch.float16),
        [
            torch.rand((1, 12), dtype=torch.float32),
            torch.rand((3, 5, 7), dtype=torch.float64),
        ],
        [torch.rand((4, 4), dtype=torch.float16)],
    ]
    source = {"foo": foo, "bar": bar, "baz": baz}

    # pack mm kwargs into a mock request so that it can be decoded properly
    request = MyRequest(mm=[MultiModalKwargs(source)])

    packer = MsgpackEncoder()
    unpacker = MsgpackDecoder(MyRequest)

    frames = packer.encode(request)

    # The encoder is expected to emit exactly six buffers for this payload.
    assert len(frames) == 6

    total_len = sum(memoryview(frame).cast("B").nbytes for frame in frames)

    # expected total encoding length, should be 44536, +-20 for minor changes
    assert 44516 <= total_len <= 44556

    decoded: MultiModalKwargs = unpacker.decode(frames).mm[0]
    for key, expected in source.items():
        assert nested_equal(expected, decoded[key])
|
||||
|
||||
|
||||
def test_multimodal_items_by_modality():
    """Round-trip a MultiModalKwargs built from per-modality items.

    Builds one audio item, one video item and one image item (the image item
    has two fields), encodes them inside a mock request, and verifies that
    the modalities, the image item's keys, and the tensor contents all
    survive decoding.
    """
    audio_elem = MultiModalFieldElem(
        "audio",
        "a0",
        torch.zeros(1000, dtype=torch.int16),
        MultiModalBatchedField(),
    )
    video_elem = MultiModalFieldElem(
        "video",
        "v0",
        [torch.zeros(1000, dtype=torch.int8) for _ in range(4)],
        MultiModalBatchedField(),
    )
    image_shared_elem = MultiModalFieldElem(
        "image",
        "i0",
        torch.zeros(1000, dtype=torch.int32),
        MultiModalSharedField(4),
    )
    image_batched_elem = MultiModalFieldElem(
        "image",
        "i1",
        torch.zeros(1000, dtype=torch.int32),
        MultiModalBatchedField(),
    )

    audio = MultiModalKwargsItem.from_elems([audio_elem])
    video = MultiModalKwargsItem.from_elems([video_elem])
    image = MultiModalKwargsItem.from_elems(
        [image_shared_elem, image_batched_elem])
    mm = MultiModalKwargs.from_items([audio, video, image])

    # pack mm kwargs into a mock request so that it can be decoded properly
    request = MyRequest([mm])

    packer = MsgpackEncoder()
    unpacker = MsgpackDecoder(MyRequest)

    frames = packer.encode(request)

    # The encoder is expected to emit exactly eight buffers for this payload.
    assert len(frames) == 8

    total_len = sum(memoryview(frame).cast("B").nbytes for frame in frames)

    # expected total encoding length, should be 14255, +-20 for minor changes
    assert 14235 <= total_len <= 14275

    decoded: MultiModalKwargs = unpacker.decode(frames).mm[0]

    # check all modalities were recovered and do some basic sanity checks
    assert len(decoded.modalities) == 3
    images = decoded.get_items("image")
    assert len(images) == 1
    image_item = images[0]
    assert len(image_item.items()) == 2
    assert list(image_item.keys()) == ["i0", "i1"]

    # check the tensor contents and layout in the main dict
    for key in mm:
        assert nested_equal(mm[key], decoded[key])
|
||||
|
||||
|
||||
def nested_equal(a: "NestedTensors", b: "NestedTensors") -> bool:
    """Return True when two nested tensor structures match exactly.

    A structure is either a ``torch.Tensor`` or a sequence whose elements
    are themselves structures. Two structures match when they have the same
    nesting, every sequence has the same length as its counterpart, and
    every pair of corresponding tensors satisfies ``torch.equal``.

    Args:
        a: The reference structure.
        b: The structure to compare against ``a``.

    Returns:
        True if both structures match, False otherwise (including when one
        side is a tensor and the other is a sequence, or when sequence
        lengths differ).
    """
    if isinstance(a, torch.Tensor):
        # torch.equal rejects non-tensor arguments, so check the type first
        # and report a plain mismatch instead of raising.
        return isinstance(b, torch.Tensor) and torch.equal(a, b)

    # zip() stops at the shorter input, so without an explicit length check
    # a truncated (or empty) sequence would compare equal to a longer one.
    if isinstance(b, torch.Tensor) or len(a) != len(b):
        return False

    return all(nested_equal(x, y) for x, y in zip(a, b))
|
||||
|
||||
|
||||
def assert_equal(obj1: MyType, obj2: MyType):
|
||||
assert torch.equal(obj1.tensor1, obj2.tensor1)
|
||||
assert obj1.a_string == obj2.a_string
|
||||
|
||||
Reference in New Issue
Block a user