[Multimodal] Simplify MM input definitions (#33331)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
Cyrus Leung
2026-01-29 21:32:04 +08:00
committed by GitHub
parent 17b17c0684
commit c6e7404cc5
17 changed files with 142 additions and 164 deletions

View File

@@ -36,8 +36,6 @@ pytestmark = pytest.mark.cpu_test
def _dummy_elem(
modality: str,
key: str,
size: int,
*,
rng: np.random.RandomState | None = None,
@@ -48,21 +46,18 @@ def _dummy_elem(
data = torch.from_numpy(rng.randint(4, size=(size,), dtype=np.int8))
return MultiModalFieldElem(
modality=modality,
key=key,
data=data,
field=MultiModalSharedField(batch_size=1),
)
def _dummy_item(
modality: str,
size_by_key: dict[str, int],
*,
rng: np.random.RandomState | None = None,
):
return MultiModalKwargsItem.from_elems(
[_dummy_elem(modality, key, size, rng=rng) for key, size in size_by_key.items()]
return MultiModalKwargsItem(
{key: _dummy_elem(size, rng=rng) for key, size in size_by_key.items()}
)
@@ -71,19 +66,19 @@ def _dummy_items(
*,
rng: np.random.RandomState | None = None,
):
return MultiModalKwargsItems.from_seq(
[
_dummy_item(modality, size_by_key, rng=rng)
return MultiModalKwargsItems(
{
modality: [_dummy_item(size_by_key, rng=rng)]
for modality, size_by_key in size_by_key_modality.items()
]
}
)
@pytest.mark.parametrize(
("item", "expected_size"),
[
(_dummy_item("a", {"a1": 100}), 100),
(_dummy_item("a", {"a1": 100, "a2": 110}), 210),
(_dummy_item({"a1": 100}), 100),
(_dummy_item({"a1": 100, "a2": 110}), 210),
(_dummy_items({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}), 460), # noqa: E501
],
)
@@ -143,7 +138,7 @@ def _compare_caches(
rng = np.random.RandomState(seed)
all_items = [
_dummy_item("item", {"key": item_size_gb}, rng=rng)
_dummy_item({"key": item_size_gb}, rng=rng)
for _ in range(int(item_capacity / hit_rate))
]
all_hashes = [
@@ -245,13 +240,13 @@ def _run_test_cache_eviction_lru(
"image_C",
]
request1_items = {
h: MultiModalKwargsItem.dummy(h, nbytes=2 * base_item_size)
h: MultiModalKwargsItem.dummy(nbytes=2 * base_item_size)
for h in request1_hashes
}
request2_hashes = ["image_D", "image_E", "image_A", "image_C"]
request2_items = {
h: MultiModalKwargsItem.dummy(h, nbytes=1 * base_item_size)
h: MultiModalKwargsItem.dummy(nbytes=1 * base_item_size)
for h in request2_hashes
}
@@ -356,15 +351,14 @@ def _run_test_cache_eviction_shm(
):
request1_hashes = ["image_A", "image_B", "image_C"]
request1_items = {
h: MultiModalKwargsItem.dummy(h, nbytes=5 * base_item_size)
for h in request1_hashes
h: MultiModalKwargsItem.dummy(5 * base_item_size) for h in request1_hashes
}
request1_items_p0_result = []
request2_hashes = ["image_G", "image_A"]
request2_items = {
h: MultiModalKwargsItem.dummy(
h, nbytes=(5 if h in request1_hashes else 2) * base_item_size
(5 if h in request1_hashes else 2) * base_item_size
)
for h in request2_hashes
}
@@ -373,7 +367,7 @@ def _run_test_cache_eviction_shm(
request3_hashes = ["image_G", "image_H", "image_I", "image_B"]
request3_items = {
h: MultiModalKwargsItem.dummy(
h, nbytes=(5 if h in request1_hashes else 2) * base_item_size
(5 if h in request1_hashes else 2) * base_item_size
)
for h in request3_hashes
}
@@ -532,7 +526,7 @@ def test_processor_cache_shared_across_loras():
lora_a_identifier = f"12345:{base_mm_hash}"
lora_b_identifier = f"67890:{base_mm_hash}"
item_data = MultiModalKwargsItem.dummy("test_image", nbytes=1024)
item_data = MultiModalKwargsItem.dummy(1024)
feature_lora_a = MultiModalFeatureSpec(
data=item_data,