2025-02-02 14:58:18 -05:00
|
|
|
# SPDX-License-Identifier: Apache-2.0
|
2025-06-03 11:20:17 -07:00
|
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
2024-06-07 11:23:32 -07:00
|
|
|
import pytest
|
2025-12-16 14:18:17 -08:00
|
|
|
import torch
|
2025-06-05 14:59:28 +02:00
|
|
|
|
2026-02-06 04:40:58 +08:00
|
|
|
from vllm.multimodal.inputs import (
|
|
|
|
|
MultiModalBatchedField,
|
|
|
|
|
MultiModalFieldElem,
|
|
|
|
|
MultiModalKwargsItem,
|
|
|
|
|
MultiModalSharedField,
|
|
|
|
|
PlaceholderRange,
|
|
|
|
|
)
|
|
|
|
|
from vllm.multimodal.utils import argsort_mm_positions, group_and_batch_mm_items
|
2025-09-12 00:44:34 +08:00
|
|
|
|
|
|
|
|
|
2025-09-26 02:23:01 +08:00
|
|
|
@pytest.mark.parametrize(
    "case",
    [
        # Single modality, placeholders already in offset order.
        {
            "mm_positions": {
                "image": [
                    PlaceholderRange(offset=0, length=2),
                    PlaceholderRange(offset=3, length=2),
                ]
            },
            "expected_modality_idxs": [("image", 0), ("image", 1)],
        },
        # Single modality, placeholders listed out of offset order.
        {
            "mm_positions": {
                "image": [
                    PlaceholderRange(offset=3, length=2),
                    PlaceholderRange(offset=0, length=2),
                ]
            },
            "expected_modality_idxs": [("image", 1), ("image", 0)],
        },
        # Two modalities, each internally sorted, not interleaved.
        {
            "mm_positions": {
                "image": [
                    PlaceholderRange(offset=7, length=4),
                    PlaceholderRange(offset=11, length=5),
                ],
                "audio": [
                    PlaceholderRange(offset=0, length=2),
                    PlaceholderRange(offset=2, length=3),
                ],
            },
            "expected_modality_idxs": [
                ("audio", 0),
                ("audio", 1),
                ("image", 0),
                ("image", 1),
            ],
        },
        # Two modalities, interleaved, each internally sorted.
        {
            "mm_positions": {
                "image": [
                    PlaceholderRange(offset=0, length=4),
                    PlaceholderRange(offset=8, length=2),
                ],
                "audio": [
                    PlaceholderRange(offset=5, length=2),
                    PlaceholderRange(offset=11, length=4),
                ],
            },
            "expected_modality_idxs": [
                ("image", 0),
                ("audio", 0),
                ("image", 1),
                ("audio", 1),
            ],
        },
        # Two modalities, interleaved, each internally unsorted.
        {
            "mm_positions": {
                "image": [
                    PlaceholderRange(offset=8, length=2),
                    PlaceholderRange(offset=0, length=4),
                ],
                "audio": [
                    PlaceholderRange(offset=11, length=4),
                    PlaceholderRange(offset=5, length=2),
                ],
            },
            "expected_modality_idxs": [
                ("image", 1),
                ("audio", 1),
                ("image", 0),
                ("audio", 0),
            ],
        },
        # Three modalities, each internally sorted, not interleaved.
        {
            "mm_positions": {
                "image": [
                    PlaceholderRange(offset=15, length=7),
                    PlaceholderRange(offset=22, length=8),
                ],
                "audio": [
                    PlaceholderRange(offset=0, length=2),
                ],
                "video": [
                    PlaceholderRange(offset=3, length=4),
                    PlaceholderRange(offset=7, length=5),
                    PlaceholderRange(offset=12, length=6),
                ],
            },
            "expected_modality_idxs": [
                ("audio", 0),
                ("video", 0),
                ("video", 1),
                ("video", 2),
                ("image", 0),
                ("image", 1),
            ],
        },
        # Three modalities, interleaved, each internally sorted.
        {
            "mm_positions": {
                "image": [
                    PlaceholderRange(offset=0, length=2),
                    PlaceholderRange(offset=2, length=3),
                    PlaceholderRange(offset=20, length=4),
                ],
                "audio": [
                    PlaceholderRange(offset=5, length=2),
                ],
                "video": [
                    PlaceholderRange(offset=8, length=5),
                ],
            },
            "expected_modality_idxs": [
                ("image", 0),
                ("image", 1),
                ("audio", 0),
                ("video", 0),
                ("image", 2),
            ],
        },
        # Three modalities, interleaved, "image" internally unsorted.
        {
            "mm_positions": {
                "image": [
                    PlaceholderRange(offset=0, length=2),
                    PlaceholderRange(offset=20, length=4),
                    PlaceholderRange(offset=2, length=3),
                ],
                "audio": [
                    PlaceholderRange(offset=5, length=2),
                ],
                "video": [
                    PlaceholderRange(offset=8, length=5),
                ],
            },
            "expected_modality_idxs": [
                ("image", 0),
                ("image", 2),
                ("audio", 0),
                ("video", 0),
                ("image", 1),
            ],
        },
    ],
)
def test_argsort_mm_positions(case):
    """Check that `argsort_mm_positions` returns (modality, index) pairs
    ordered by placeholder offset, across and within modalities."""
    positions = case["mm_positions"]
    expected = case["expected_modality_idxs"]

    actual = argsort_mm_positions(positions)
    assert actual == expected
|
2025-09-26 21:23:52 -04:00
|
|
|
|
|
|
|
|
|
2026-02-06 04:40:58 +08:00
|
|
|
def test_group_and_batch_mm_items_split_by_fieldset():
    """Check how `group_and_batch_mm_items` splits items by their key set.

    Expected grouping (encoded in the final assertion):
    items 1-2 share the key set {x, y} (in a different insertion order) and
    form one group of 2; items 3, 4 and 5 each form their own group.  Item 5
    is not merged with items 1-2 even though its keys match, since it is not
    adjacent to them.
    """
    # Use deterministic contents: `torch.empty` returns uninitialized memory,
    # and this test should depend only on the items' key sets, never on
    # whatever bytes happen to be in the buffer.  Matches the zeros used by
    # the shared-data test.
    elem = MultiModalFieldElem(
        data=torch.zeros(1, dtype=torch.uint8),
        field=MultiModalBatchedField(),
    )

    items = [
        MultiModalKwargsItem({"x": elem, "y": elem}),
        # Same keys as the previous item, different insertion order.
        MultiModalKwargsItem({"y": elem, "x": elem}),
        # Superset of the previous key set.
        MultiModalKwargsItem({"x": elem, "y": elem, "z": elem}),
        # Subset of the previous key set.
        MultiModalKwargsItem({"x": elem}),
        # Same keys as the first two items, but not adjacent to them.
        MultiModalKwargsItem({"x": elem, "y": elem}),
    ]

    res = group_and_batch_mm_items(items)

    assert [num_items for num_items, _ in res] == [2, 1, 1, 1]
|
2025-09-26 21:23:52 -04:00
|
|
|
|
|
|
|
|
|
2026-02-06 04:40:58 +08:00
|
|
|
def test_group_and_batch_mm_items_split_by_shared_data():
    """Check how `group_and_batch_mm_items` splits shared-field items.

    Two shared elements with different data (sizes 1 and 2) are placed in the
    sequence A, A, B, A, B.  The assertion expects only the leading A, A pair
    to be batched together; every later change of element starts a new
    group.
    """

    def make_shared_elem(size: int) -> MultiModalFieldElem:
        # One shared field element holding `size` zero bytes.
        return MultiModalFieldElem(
            data=torch.zeros(size, dtype=torch.uint8),
            field=MultiModalSharedField(batch_size=1),
        )

    elem_a = make_shared_elem(1)
    elem_b = make_shared_elem(2)

    sequence = (elem_a, elem_a, elem_b, elem_a, elem_b)
    items = [MultiModalKwargsItem({"x": e}) for e in sequence]

    group_sizes = [n for n, _ in group_and_batch_mm_items(items)]
    assert group_sizes == [2, 1, 1, 1]
|