[Refactor] Use data parser for matching data items to multi-modal UUIDs (#32955)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -68,6 +68,16 @@ def phi3v_model_config_image_embeds():
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def qwen25omni_model_config_image_embeds():
|
||||
return ModelConfig(
|
||||
QWEN25OMNI_MODEL_ID,
|
||||
runner="generate",
|
||||
limit_mm_per_prompt={"image": 2},
|
||||
enable_mm_embeds=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def qwen2_audio_model_config():
|
||||
return ModelConfig(
|
||||
@@ -823,7 +833,8 @@ def test_parse_chat_messages_audio_embeds_with_string(
|
||||
import torch
|
||||
|
||||
# Create a sample audio embedding tensor
|
||||
audio_embedding = torch.randn(1, 128, 768)
|
||||
hidden_size = audio_embeds_model_config.get_inputs_embeds_size()
|
||||
audio_embedding = torch.randn(1, 128, hidden_size)
|
||||
|
||||
# Encode it as base64
|
||||
base64_audio_embedding = tensor2base64(audio_embedding)
|
||||
@@ -865,7 +876,8 @@ async def test_parse_chat_messages_audio_embeds_async(
|
||||
import torch
|
||||
|
||||
# Create a sample audio embedding tensor
|
||||
audio_embedding = torch.randn(1, 128, 768)
|
||||
hidden_size = audio_embeds_model_config.get_inputs_embeds_size()
|
||||
audio_embedding = torch.randn(1, 128, hidden_size)
|
||||
|
||||
# Encode it as base64
|
||||
base64_audio_embedding = tensor2base64(audio_embedding)
|
||||
@@ -908,8 +920,9 @@ def test_parse_chat_messages_multiple_image_embeds(
|
||||
can be provided in a single request, similar to regular images.
|
||||
"""
|
||||
# Create two sample image embedding tensors
|
||||
image_embedding_1 = torch.randn(256, 1024)
|
||||
image_embedding_2 = torch.randn(128, 1024)
|
||||
hidden_size = phi3v_model_config_image_embeds.get_inputs_embeds_size()
|
||||
image_embedding_1 = torch.randn(256, hidden_size)
|
||||
image_embedding_2 = torch.randn(128, hidden_size)
|
||||
|
||||
# Encode them as base64 using the convenience function
|
||||
base64_image_embedding_1 = tensor2base64(image_embedding_1)
|
||||
@@ -1022,8 +1035,9 @@ async def test_parse_chat_messages_multiple_image_embeds_async(
|
||||
This validates the AsyncMultiModalItemTracker also supports multiple embeddings.
|
||||
"""
|
||||
# Create two sample image embedding tensors
|
||||
image_embedding_1 = torch.randn(200, 768)
|
||||
image_embedding_2 = torch.randn(150, 768)
|
||||
hidden_size = phi3v_model_config_image_embeds.get_inputs_embeds_size()
|
||||
image_embedding_1 = torch.randn(200, hidden_size)
|
||||
image_embedding_2 = torch.randn(150, hidden_size)
|
||||
|
||||
# Encode them as base64 using the convenience function
|
||||
base64_image_embedding_1 = tensor2base64(image_embedding_1)
|
||||
@@ -1145,13 +1159,14 @@ def test_parse_chat_messages_empty_dict_image_embeds(
|
||||
|
||||
|
||||
def test_parse_chat_messages_multiple_dict_image_embeds(
|
||||
phi3v_model_config_image_embeds,
|
||||
qwen25omni_model_config_image_embeds,
|
||||
):
|
||||
"""Test that multiple dictionaries for image_embeds is handled without errors."""
|
||||
# Create two sample image embedding tensors
|
||||
batch_size = 2
|
||||
image_embedding_1 = torch.randn(batch_size, 256, 1024)
|
||||
image_embedding_2 = torch.randn(batch_size, 3)
|
||||
hidden_size = qwen25omni_model_config_image_embeds.get_inputs_embeds_size()
|
||||
image_embeds = torch.randn(batch_size * 220, hidden_size)
|
||||
image_grid_thw = torch.tensor([[1, 22, 40] for _ in range(batch_size)])
|
||||
|
||||
conversation, mm_data, mm_uuids = parse_chat_messages(
|
||||
[
|
||||
@@ -1161,18 +1176,20 @@ def test_parse_chat_messages_multiple_dict_image_embeds(
|
||||
{
|
||||
"type": "image_embeds",
|
||||
"image_embeds": {
|
||||
"image_embedding_1": tensor2base64(p),
|
||||
"image_embedding_2": tensor2base64(i),
|
||||
"image_embeds": tensor2base64(embeds),
|
||||
"image_grid_thw": tensor2base64(grid_thw),
|
||||
},
|
||||
}
|
||||
for p, i in zip(image_embedding_1, image_embedding_2)
|
||||
for embeds, grid_thw in zip(
|
||||
image_embeds.chunk(batch_size), image_grid_thw
|
||||
)
|
||||
]
|
||||
+ [
|
||||
{"type": "text", "text": "Describe these two images."},
|
||||
],
|
||||
}
|
||||
],
|
||||
phi3v_model_config_image_embeds,
|
||||
qwen25omni_model_config_image_embeds,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
@@ -1180,7 +1197,8 @@ def test_parse_chat_messages_multiple_dict_image_embeds(
|
||||
assert conversation == [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "<|image_1|>\n<|image_2|>\nDescribe these two images.",
|
||||
"content": "<|vision_start|><|IMAGE|><|vision_end|>\n"
|
||||
"<|vision_start|><|IMAGE|><|vision_end|>\nDescribe these two images.",
|
||||
}
|
||||
]
|
||||
|
||||
@@ -1191,10 +1209,10 @@ def test_parse_chat_messages_multiple_dict_image_embeds(
|
||||
assert len(mm_data["image"]) == batch_size
|
||||
|
||||
# Verify each embedding has the correct shape
|
||||
assert isinstance(mm_data["image"]["image_embedding_1"], torch.Tensor)
|
||||
assert mm_data["image"]["image_embedding_1"].shape == image_embedding_1.shape
|
||||
assert isinstance(mm_data["image"]["image_embedding_2"], torch.Tensor)
|
||||
assert mm_data["image"]["image_embedding_2"].shape == image_embedding_2.shape
|
||||
assert isinstance(mm_data["image"]["image_embeds"], torch.Tensor)
|
||||
assert mm_data["image"]["image_embeds"].shape == image_embeds.shape
|
||||
assert isinstance(mm_data["image"]["image_grid_thw"], torch.Tensor)
|
||||
assert mm_data["image"]["image_grid_thw"].shape == image_grid_thw.shape
|
||||
|
||||
# Verify UUIDs (None since we didn't provide any)
|
||||
_assert_mm_uuids(mm_uuids, batch_size, expected_uuids=[None, None])
|
||||
|
||||
Reference in New Issue
Block a user