[Refactor] Use data parser for matching data items to multi-modal UUIDs (#32955)

Author: Cyrus Leung
Date: 2026-01-26 15:00:28 +08:00
Committed by: GitHub
Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Parent: ee484b3f4b
Commit: 11b556878b
14 changed files with 701 additions and 604 deletions


@@ -68,6 +68,16 @@ def phi3v_model_config_image_embeds():
     )
 
 
+@pytest.fixture(scope="function")
+def qwen25omni_model_config_image_embeds():
+    return ModelConfig(
+        QWEN25OMNI_MODEL_ID,
+        runner="generate",
+        limit_mm_per_prompt={"image": 2},
+        enable_mm_embeds=True,
+    )
+
+
 @pytest.fixture(scope="function")
 def qwen2_audio_model_config():
     return ModelConfig(
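The new Qwen2.5-Omni fixture is consumed the same way as the other model-config fixtures in this file: tests derive the embedding hidden size from the config instead of hardcoding it. A minimal, hypothetical usage sketch, assuming only the get_inputs_embeds_size() helper that the hunks below rely on (QWEN25OMNI_MODEL_ID is defined elsewhere in the test module):

import torch


def test_image_embeds_match_hidden_size(qwen25omni_model_config_image_embeds):
    # Size the fake embeddings from the model config rather than a literal
    # hidden dimension, so the test tracks the reference checkpoint.
    hidden_size = qwen25omni_model_config_image_embeds.get_inputs_embeds_size()
    image_embeds = torch.randn(220, hidden_size)
    assert image_embeds.shape[-1] == hidden_size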
@@ -823,7 +833,8 @@ def test_parse_chat_messages_audio_embeds_with_string(
     import torch
 
     # Create a sample audio embedding tensor
-    audio_embedding = torch.randn(1, 128, 768)
+    hidden_size = audio_embeds_model_config.get_inputs_embeds_size()
+    audio_embedding = torch.randn(1, 128, hidden_size)
 
     # Encode it as base64
     base64_audio_embedding = tensor2base64(audio_embedding)
@@ -865,7 +876,8 @@ async def test_parse_chat_messages_audio_embeds_async(
     import torch
 
     # Create a sample audio embedding tensor
-    audio_embedding = torch.randn(1, 128, 768)
+    hidden_size = audio_embeds_model_config.get_inputs_embeds_size()
+    audio_embedding = torch.randn(1, 128, hidden_size)
 
     # Encode it as base64
     base64_audio_embedding = tensor2base64(audio_embedding)
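Both audio hunks lean on the tensor2base64 convenience helper that already exists in the test suite; its implementation is not part of this diff. A rough sketch of what such a helper could look like (the serialization format here is an assumption, not necessarily the suite's actual choice):

import base64
import io

import torch


def tensor2base64(tensor: torch.Tensor) -> str:
    # Serialize the tensor to bytes, then base64-encode the payload so it can
    # be embedded as a plain string inside a JSON chat message.
    buffer = io.BytesIO()
    torch.save(tensor, buffer)
    return base64.b64encode(buffer.getvalue()).decode("utf-8")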
@@ -908,8 +920,9 @@ def test_parse_chat_messages_multiple_image_embeds(
     can be provided in a single request, similar to regular images.
     """
     # Create two sample image embedding tensors
-    image_embedding_1 = torch.randn(256, 1024)
-    image_embedding_2 = torch.randn(128, 1024)
+    hidden_size = phi3v_model_config_image_embeds.get_inputs_embeds_size()
+    image_embedding_1 = torch.randn(256, hidden_size)
+    image_embedding_2 = torch.randn(128, hidden_size)
 
     # Encode them as base64 using the convenience function
     base64_image_embedding_1 = tensor2base64(image_embedding_1)
@@ -1022,8 +1035,9 @@ async def test_parse_chat_messages_multiple_image_embeds_async(
     This validates the AsyncMultiModalItemTracker also supports multiple embeddings.
     """
     # Create two sample image embedding tensors
-    image_embedding_1 = torch.randn(200, 768)
-    image_embedding_2 = torch.randn(150, 768)
+    hidden_size = phi3v_model_config_image_embeds.get_inputs_embeds_size()
+    image_embedding_1 = torch.randn(200, hidden_size)
+    image_embedding_2 = torch.randn(150, hidden_size)
 
     # Encode them as base64 using the convenience function
     base64_image_embedding_1 = tensor2base64(image_embedding_1)
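For orientation, the chat messages these two tests feed to parse_chat_messages pair each base64-encoded embedding with an image_embeds content part, followed by a text part. This is an illustrative sketch only, mirroring the field names visible in this file rather than quoting the unchanged test body:

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image_embeds", "image_embeds": base64_image_embedding_1},
            {"type": "image_embeds", "image_embeds": base64_image_embedding_2},
            {"type": "text", "text": "Describe these two images."},
        ],
    }
]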
@@ -1145,13 +1159,14 @@ def test_parse_chat_messages_empty_dict_image_embeds(
 def test_parse_chat_messages_multiple_dict_image_embeds(
-    phi3v_model_config_image_embeds,
+    qwen25omni_model_config_image_embeds,
 ):
     """Test that multiple dictionaries for image_embeds is handled without errors."""
     # Create two sample image embedding tensors
     batch_size = 2
-    image_embedding_1 = torch.randn(batch_size, 256, 1024)
-    image_embedding_2 = torch.randn(batch_size, 3)
+    hidden_size = qwen25omni_model_config_image_embeds.get_inputs_embeds_size()
+    image_embeds = torch.randn(batch_size * 220, hidden_size)
+    image_grid_thw = torch.tensor([[1, 22, 40] for _ in range(batch_size)])
 
     conversation, mm_data, mm_uuids = parse_chat_messages(
         [
@@ -1161,18 +1176,20 @@ def test_parse_chat_messages_multiple_dict_image_embeds(
                 {
                     "type": "image_embeds",
                     "image_embeds": {
-                        "image_embedding_1": tensor2base64(p),
-                        "image_embedding_2": tensor2base64(i),
+                        "image_embeds": tensor2base64(embeds),
+                        "image_grid_thw": tensor2base64(grid_thw),
                     },
                 }
-                for p, i in zip(image_embedding_1, image_embedding_2)
+                for embeds, grid_thw in zip(
+                    image_embeds.chunk(batch_size), image_grid_thw
+                )
             ]
             + [
                 {"type": "text", "text": "Describe these two images."},
             ],
         }
     ],
-        phi3v_model_config_image_embeds,
+        qwen25omni_model_config_image_embeds,
         content_format="string",
     )
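The tensor shapes used in this test are mutually consistent if one assumes the 2x2 spatial merge that Qwen2.5-style vision towers apply to their patch grid; the merge factor is an assumption here, not something the diff states:

# Rough shape arithmetic for the dict-embeds test above.
t, h, w = 1, 22, 40                    # one image_grid_thw entry
patches = t * h * w                    # 880 vision patches per image
merge_size = 2                         # assumed 2x2 spatial merge
rows_per_image = patches // (merge_size * merge_size)
assert rows_per_image == 220           # matches torch.randn(batch_size * 220, hidden_size)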
@@ -1180,7 +1197,8 @@ def test_parse_chat_messages_multiple_dict_image_embeds(
     assert conversation == [
         {
             "role": "user",
-            "content": "<|image_1|>\n<|image_2|>\nDescribe these two images.",
+            "content": "<|vision_start|><|IMAGE|><|vision_end|>\n"
+            "<|vision_start|><|IMAGE|><|vision_end|>\nDescribe these two images.",
         }
     ]
@@ -1191,10 +1209,10 @@ def test_parse_chat_messages_multiple_dict_image_embeds(
     assert len(mm_data["image"]) == batch_size
 
     # Verify each embedding has the correct shape
-    assert isinstance(mm_data["image"]["image_embedding_1"], torch.Tensor)
-    assert mm_data["image"]["image_embedding_1"].shape == image_embedding_1.shape
-    assert isinstance(mm_data["image"]["image_embedding_2"], torch.Tensor)
-    assert mm_data["image"]["image_embedding_2"].shape == image_embedding_2.shape
+    assert isinstance(mm_data["image"]["image_embeds"], torch.Tensor)
+    assert mm_data["image"]["image_embeds"].shape == image_embeds.shape
+    assert isinstance(mm_data["image"]["image_grid_thw"], torch.Tensor)
+    assert mm_data["image"]["image_grid_thw"].shape == image_grid_thw.shape
 
     # Verify UUIDs (None since we didn't provide any)
     _assert_mm_uuids(mm_uuids, batch_size, expected_uuids=[None, None])