[Refactor] Define MultiModalKwargsItems separate from MultiModalKwargs (#23053)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -11,7 +11,8 @@ import torch
|
||||
|
||||
from vllm.multimodal.inputs import (MultiModalBatchedField,
|
||||
MultiModalFieldElem, MultiModalFlatField,
|
||||
MultiModalKwargs, MultiModalKwargsItem,
|
||||
MultiModalKwargsItem,
|
||||
MultiModalKwargsItems,
|
||||
MultiModalSharedField, NestedTensors)
|
||||
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
|
||||
|
||||
@@ -96,7 +97,7 @@ def test_encode_decode(monkeypatch: pytest.MonkeyPatch):
|
||||
|
||||
|
||||
class MyRequest(msgspec.Struct):
|
||||
mm: Optional[list[MultiModalKwargs]]
|
||||
mm: Optional[list[MultiModalKwargsItems]]
|
||||
|
||||
|
||||
def test_multimodal_kwargs():
|
||||
@@ -119,7 +120,7 @@ def test_multimodal_kwargs():
|
||||
audio = MultiModalKwargsItem.from_elems([e1])
|
||||
video = MultiModalKwargsItem.from_elems([e2])
|
||||
image = MultiModalKwargsItem.from_elems([e3, e4])
|
||||
mm = MultiModalKwargs([audio, video, image])
|
||||
mm = MultiModalKwargsItems.from_seq([audio, video, image])
|
||||
|
||||
# pack mm kwargs into a mock request so that it can be decoded properly
|
||||
req = MyRequest([mm])
|
||||
@@ -133,19 +134,22 @@ def test_multimodal_kwargs():
|
||||
|
||||
total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)
|
||||
|
||||
# expected total encoding length, should be 14255, +-20 for minor changes
|
||||
assert 14250 <= total_len <= 14300
|
||||
decoded: MultiModalKwargs = decoder.decode(encoded).mm[0]
|
||||
# expected total encoding length, should be 14306, +-20 for minor changes
|
||||
assert 14275 <= total_len <= 14325
|
||||
decoded = decoder.decode(encoded).mm[0]
|
||||
assert isinstance(decoded, MultiModalKwargsItems)
|
||||
|
||||
# check all modalities were recovered and do some basic sanity checks
|
||||
assert len(decoded.modalities) == 3
|
||||
images = decoded.get_items("image")
|
||||
assert len(decoded) == 3
|
||||
images = decoded["image"]
|
||||
assert len(images) == 1
|
||||
assert len(images[0].items()) == 2
|
||||
assert list(images[0].keys()) == ["i0", "i1"]
|
||||
|
||||
# check the tensor contents and layout in the main dict
|
||||
assert all(nested_equal(mm[k], decoded[k]) for k in mm)
|
||||
mm_data = mm.get_data()
|
||||
decoded_data = decoded.get_data()
|
||||
assert all(nested_equal(mm_data[k], decoded_data[k]) for k in mm_data)
|
||||
|
||||
|
||||
def nested_equal(a: NestedTensors, b: NestedTensors):
|
||||
|
||||
Reference in New Issue
Block a user