[Bugfix] Dictionary MM embeddings for online chat (#30507)
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
This commit is contained in:
@@ -796,9 +796,13 @@ def test_parse_chat_messages_empty_image_embeds_with_uuid(
|
||||
"content": "<|image_1|>\nWhat's in this image?",
|
||||
}
|
||||
]
|
||||
|
||||
assert mm_data is not None
|
||||
assert "image" in mm_data
|
||||
assert mm_data["image"] is None
|
||||
assert isinstance(mm_data["image"], list)
|
||||
assert len(mm_data["image"]) == 1
|
||||
assert mm_data["image"][0] is None
|
||||
|
||||
_assert_mm_uuids(mm_uuids, 1, expected_uuids=[uuid])
|
||||
|
||||
|
||||
@@ -825,10 +829,11 @@ def test_parse_chat_messages_empty_audio_embeds_with_uuid(
|
||||
# Should have audio in mm_data as None (UUID provided)
|
||||
assert mm_data is not None
|
||||
assert "audio" in mm_data
|
||||
assert mm_data["audio"] is None
|
||||
assert isinstance(mm_data["audio"], list)
|
||||
assert len(mm_data["audio"]) == 1
|
||||
assert mm_data["audio"][0] is None
|
||||
|
||||
# UUID should be recorded
|
||||
assert mm_uuids is not None
|
||||
assert "audio" in mm_uuids
|
||||
_assert_mm_uuids(mm_uuids, 1, modality="audio", expected_uuids=[uuid])
|
||||
|
||||
|
||||
@@ -1121,10 +1126,105 @@ async def test_parse_chat_messages_empty_image_embeds_with_uuid_async(
|
||||
mm_data = await mm_future
|
||||
assert mm_data is not None
|
||||
assert "image" in mm_data
|
||||
assert mm_data["image"] is None
|
||||
assert isinstance(mm_data["image"], list)
|
||||
assert len(mm_data["image"]) == 1
|
||||
assert mm_data["image"][0] is None
|
||||
|
||||
_assert_mm_uuids(mm_uuids, 1, expected_uuids=[uuid])
|
||||
|
||||
|
||||
def test_parse_chat_messages_empty_dict_image_embeds(
|
||||
phi3v_model_config_image_embeds,
|
||||
):
|
||||
"""Test that empty dictionary for image_embeds is handled without errors."""
|
||||
conversation, mm_data, mm_uuids = parse_chat_messages(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_embeds", "image_embeds": {}},
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
],
|
||||
}
|
||||
],
|
||||
phi3v_model_config_image_embeds,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
# Verify conversation structure
|
||||
assert conversation == [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "<|image_1|>\nWhat's in this image?",
|
||||
}
|
||||
]
|
||||
|
||||
# Verify mm_data contains an empty dictionary of embeddings
|
||||
assert mm_data is not None
|
||||
assert "image" in mm_data
|
||||
assert isinstance(mm_data["image"], dict)
|
||||
assert len(mm_data["image"]) == 0
|
||||
|
||||
# Verify UUIDs (None since we didn't provide any)
|
||||
_assert_mm_uuids(mm_uuids, 1, expected_uuids=[None])
|
||||
|
||||
|
||||
def test_parse_chat_messages_multiple_dict_image_embeds(
|
||||
phi3v_model_config_image_embeds,
|
||||
):
|
||||
"""Test that multiple dictionaries for image_embeds is handled without errors."""
|
||||
# Create two sample image embedding tensors
|
||||
batch_size = 2
|
||||
image_embedding_1 = torch.randn(batch_size, 256, 1024)
|
||||
image_embedding_2 = torch.randn(batch_size, 3)
|
||||
|
||||
conversation, mm_data, mm_uuids = parse_chat_messages(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_embeds",
|
||||
"image_embeds": {
|
||||
"image_embedding_1": tensor2base64(p),
|
||||
"image_embedding_2": tensor2base64(i),
|
||||
},
|
||||
}
|
||||
for p, i in zip(image_embedding_1, image_embedding_2)
|
||||
]
|
||||
+ [
|
||||
{"type": "text", "text": "Describe these two images."},
|
||||
],
|
||||
}
|
||||
],
|
||||
phi3v_model_config_image_embeds,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
# Verify conversation structure
|
||||
assert conversation == [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "<|image_1|>\n<|image_2|>\nDescribe these two images.",
|
||||
}
|
||||
]
|
||||
|
||||
# Verify mm_data contains a dictionary of multi-embeddings
|
||||
assert mm_data is not None
|
||||
assert "image" in mm_data
|
||||
assert isinstance(mm_data["image"], dict)
|
||||
assert len(mm_data["image"]) == batch_size
|
||||
|
||||
# Verify each embedding has the correct shape
|
||||
assert isinstance(mm_data["image"]["image_embedding_1"], torch.Tensor)
|
||||
assert mm_data["image"]["image_embedding_1"].shape == image_embedding_1.shape
|
||||
assert isinstance(mm_data["image"]["image_embedding_2"], torch.Tensor)
|
||||
assert mm_data["image"]["image_embedding_2"].shape == image_embedding_2.shape
|
||||
|
||||
# Verify UUIDs (None since we didn't provide any)
|
||||
_assert_mm_uuids(mm_uuids, batch_size, expected_uuids=[None, None])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_chat_messages_multiple_images_async(
|
||||
phi3v_model_config,
|
||||
|
||||
Reference in New Issue
Block a user