[Multimodal] Simplify MM input definitions (#33331)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -36,8 +36,6 @@ pytestmark = pytest.mark.cpu_test
|
||||
|
||||
|
||||
def _dummy_elem(
|
||||
modality: str,
|
||||
key: str,
|
||||
size: int,
|
||||
*,
|
||||
rng: np.random.RandomState | None = None,
|
||||
@@ -48,21 +46,18 @@ def _dummy_elem(
|
||||
data = torch.from_numpy(rng.randint(4, size=(size,), dtype=np.int8))
|
||||
|
||||
return MultiModalFieldElem(
|
||||
modality=modality,
|
||||
key=key,
|
||||
data=data,
|
||||
field=MultiModalSharedField(batch_size=1),
|
||||
)
|
||||
|
||||
|
||||
def _dummy_item(
|
||||
modality: str,
|
||||
size_by_key: dict[str, int],
|
||||
*,
|
||||
rng: np.random.RandomState | None = None,
|
||||
):
|
||||
return MultiModalKwargsItem.from_elems(
|
||||
[_dummy_elem(modality, key, size, rng=rng) for key, size in size_by_key.items()]
|
||||
return MultiModalKwargsItem(
|
||||
{key: _dummy_elem(size, rng=rng) for key, size in size_by_key.items()}
|
||||
)
|
||||
|
||||
|
||||
@@ -71,19 +66,19 @@ def _dummy_items(
|
||||
*,
|
||||
rng: np.random.RandomState | None = None,
|
||||
):
|
||||
return MultiModalKwargsItems.from_seq(
|
||||
[
|
||||
_dummy_item(modality, size_by_key, rng=rng)
|
||||
return MultiModalKwargsItems(
|
||||
{
|
||||
modality: [_dummy_item(size_by_key, rng=rng)]
|
||||
for modality, size_by_key in size_by_key_modality.items()
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("item", "expected_size"),
|
||||
[
|
||||
(_dummy_item("a", {"a1": 100}), 100),
|
||||
(_dummy_item("a", {"a1": 100, "a2": 110}), 210),
|
||||
(_dummy_item({"a1": 100}), 100),
|
||||
(_dummy_item({"a1": 100, "a2": 110}), 210),
|
||||
(_dummy_items({"a": {"a1": 100, "a2": 110}, "b": {"b1": 120, "b2": 130}}), 460), # noqa: E501
|
||||
],
|
||||
)
|
||||
@@ -143,7 +138,7 @@ def _compare_caches(
|
||||
|
||||
rng = np.random.RandomState(seed)
|
||||
all_items = [
|
||||
_dummy_item("item", {"key": item_size_gb}, rng=rng)
|
||||
_dummy_item({"key": item_size_gb}, rng=rng)
|
||||
for _ in range(int(item_capacity / hit_rate))
|
||||
]
|
||||
all_hashes = [
|
||||
@@ -245,13 +240,13 @@ def _run_test_cache_eviction_lru(
|
||||
"image_C",
|
||||
]
|
||||
request1_items = {
|
||||
h: MultiModalKwargsItem.dummy(h, nbytes=2 * base_item_size)
|
||||
h: MultiModalKwargsItem.dummy(nbytes=2 * base_item_size)
|
||||
for h in request1_hashes
|
||||
}
|
||||
|
||||
request2_hashes = ["image_D", "image_E", "image_A", "image_C"]
|
||||
request2_items = {
|
||||
h: MultiModalKwargsItem.dummy(h, nbytes=1 * base_item_size)
|
||||
h: MultiModalKwargsItem.dummy(nbytes=1 * base_item_size)
|
||||
for h in request2_hashes
|
||||
}
|
||||
|
||||
@@ -356,15 +351,14 @@ def _run_test_cache_eviction_shm(
|
||||
):
|
||||
request1_hashes = ["image_A", "image_B", "image_C"]
|
||||
request1_items = {
|
||||
h: MultiModalKwargsItem.dummy(h, nbytes=5 * base_item_size)
|
||||
for h in request1_hashes
|
||||
h: MultiModalKwargsItem.dummy(5 * base_item_size) for h in request1_hashes
|
||||
}
|
||||
request1_items_p0_result = []
|
||||
|
||||
request2_hashes = ["image_G", "image_A"]
|
||||
request2_items = {
|
||||
h: MultiModalKwargsItem.dummy(
|
||||
h, nbytes=(5 if h in request1_hashes else 2) * base_item_size
|
||||
(5 if h in request1_hashes else 2) * base_item_size
|
||||
)
|
||||
for h in request2_hashes
|
||||
}
|
||||
@@ -373,7 +367,7 @@ def _run_test_cache_eviction_shm(
|
||||
request3_hashes = ["image_G", "image_H", "image_I", "image_B"]
|
||||
request3_items = {
|
||||
h: MultiModalKwargsItem.dummy(
|
||||
h, nbytes=(5 if h in request1_hashes else 2) * base_item_size
|
||||
(5 if h in request1_hashes else 2) * base_item_size
|
||||
)
|
||||
for h in request3_hashes
|
||||
}
|
||||
@@ -532,7 +526,7 @@ def test_processor_cache_shared_across_loras():
|
||||
lora_a_identifier = f"12345:{base_mm_hash}"
|
||||
lora_b_identifier = f"67890:{base_mm_hash}"
|
||||
|
||||
item_data = MultiModalKwargsItem.dummy("test_image", nbytes=1024)
|
||||
item_data = MultiModalKwargsItem.dummy(1024)
|
||||
|
||||
feature_lora_a = MultiModalFeatureSpec(
|
||||
data=item_data,
|
||||
|
||||
Reference in New Issue
Block a user