[Refactor] Define MultiModalKwargsItems separate from MultiModalKwargs (#23053)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2025-08-18 17:52:00 +08:00
committed by GitHub
parent 5c79b0d648
commit 27e8d1ea3e
77 changed files with 431 additions and 383 deletions

View File

@@ -11,7 +11,8 @@ import torch
from vllm.multimodal.inputs import (MultiModalBatchedField,
MultiModalFieldElem, MultiModalFlatField,
MultiModalKwargs, MultiModalKwargsItem,
MultiModalKwargsItem,
MultiModalKwargsItems,
MultiModalSharedField, NestedTensors)
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
@@ -96,7 +97,7 @@ def test_encode_decode(monkeypatch: pytest.MonkeyPatch):
class MyRequest(msgspec.Struct):
mm: Optional[list[MultiModalKwargs]]
mm: Optional[list[MultiModalKwargsItems]]
def test_multimodal_kwargs():
@@ -119,7 +120,7 @@ def test_multimodal_kwargs():
audio = MultiModalKwargsItem.from_elems([e1])
video = MultiModalKwargsItem.from_elems([e2])
image = MultiModalKwargsItem.from_elems([e3, e4])
mm = MultiModalKwargs([audio, video, image])
mm = MultiModalKwargsItems.from_seq([audio, video, image])
# pack mm kwargs into a mock request so that it can be decoded properly
req = MyRequest([mm])
@@ -133,19 +134,22 @@ def test_multimodal_kwargs():
total_len = sum(memoryview(x).cast("B").nbytes for x in encoded)
# expected total encoding length is about 14255; accept 14250-14300 to allow for minor changes
assert 14250 <= total_len <= 14300
decoded: MultiModalKwargs = decoder.decode(encoded).mm[0]
# expected total encoding length is about 14306; accept 14275-14325 to allow for minor changes
assert 14275 <= total_len <= 14325
decoded = decoder.decode(encoded).mm[0]
assert isinstance(decoded, MultiModalKwargsItems)
# check all modalities were recovered and do some basic sanity checks
assert len(decoded.modalities) == 3
images = decoded.get_items("image")
assert len(decoded) == 3
images = decoded["image"]
assert len(images) == 1
assert len(images[0].items()) == 2
assert list(images[0].keys()) == ["i0", "i1"]
# check the tensor contents and layout in the main dict
assert all(nested_equal(mm[k], decoded[k]) for k in mm)
mm_data = mm.get_data()
decoded_data = decoded.get_data()
assert all(nested_equal(mm_data[k], decoded_data[k]) for k in mm_data)
def nested_equal(a: NestedTensors, b: NestedTensors):