[Core] Add audio_embeds support to chat completions (#29059)
Signed-off-by: Jeremy Teboul <jeremyteboul@fb.com> Co-authored-by: Jeremy Teboul <jeremyteboul@fb.com>
This commit is contained in:
@@ -103,6 +103,19 @@ def qwen2_audio_model_config():
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def audio_embeds_model_config():
|
||||
return ModelConfig(
|
||||
QWEN2AUDIO_MODEL_ID,
|
||||
runner="generate",
|
||||
trust_remote_code=True,
|
||||
limit_mm_per_prompt={
|
||||
"audio": 2,
|
||||
},
|
||||
enable_mm_embeds=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def qwen2_audio_tokenizer():
|
||||
return get_tokenizer(QWEN2AUDIO_MODEL_ID)
|
||||
@@ -843,6 +856,138 @@ def test_parse_chat_messages_empty_image_embeds_with_uuid(
|
||||
_assert_mm_uuids(mm_uuids, 1, expected_uuids=[uuid])
|
||||
|
||||
|
||||
def test_parse_chat_messages_empty_audio_embeds_with_uuid(
|
||||
audio_embeds_model_config,
|
||||
qwen2_audio_tokenizer,
|
||||
):
|
||||
"""Test audio_embeds with UUID (no actual embeds data)."""
|
||||
uuid = "test-audio-uuid-123"
|
||||
|
||||
conversation, mm_data, mm_uuids = parse_chat_messages(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Describe this audio"},
|
||||
{"type": "audio_embeds", "audio_embeds": None, "uuid": uuid},
|
||||
],
|
||||
}
|
||||
],
|
||||
audio_embeds_model_config,
|
||||
qwen2_audio_tokenizer,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
# Should have audio in mm_data as None (UUID provided)
|
||||
assert mm_data is not None
|
||||
assert "audio" in mm_data
|
||||
assert mm_data["audio"] is None
|
||||
# UUID should be recorded
|
||||
assert mm_uuids is not None
|
||||
assert "audio" in mm_uuids
|
||||
_assert_mm_uuids(mm_uuids, 1, modality="audio", expected_uuids=[uuid])
|
||||
|
||||
|
||||
def test_parse_chat_messages_audio_embeds_with_string(
|
||||
audio_embeds_model_config,
|
||||
qwen2_audio_tokenizer,
|
||||
):
|
||||
"""Test audio_embeds with base64 string embedding data."""
|
||||
import base64
|
||||
import io
|
||||
|
||||
import torch
|
||||
|
||||
# Create a sample audio embedding tensor
|
||||
audio_embedding = torch.randn(1, 128, 768)
|
||||
|
||||
# Encode it as base64
|
||||
buffer = io.BytesIO()
|
||||
torch.save(audio_embedding, buffer)
|
||||
buffer.seek(0)
|
||||
binary_data = buffer.read()
|
||||
base64_audio_embedding = base64.b64encode(binary_data).decode("utf-8")
|
||||
|
||||
conversation, mm_data, mm_uuids = parse_chat_messages(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Describe this audio"},
|
||||
{
|
||||
"type": "audio_embeds",
|
||||
"audio_embeds": base64_audio_embedding,
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
audio_embeds_model_config,
|
||||
qwen2_audio_tokenizer,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
# Should have audio embedding in mm_data (single tensor, not a list)
|
||||
assert mm_data is not None
|
||||
assert "audio" in mm_data
|
||||
assert isinstance(mm_data["audio"], torch.Tensor)
|
||||
assert mm_data["audio"].shape == audio_embedding.shape
|
||||
# No UUID provided
|
||||
assert mm_uuids is not None
|
||||
assert "audio" in mm_uuids
|
||||
_assert_mm_uuids(mm_uuids, 1, modality="audio", expected_uuids=[None])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_chat_messages_audio_embeds_async(
|
||||
audio_embeds_model_config,
|
||||
qwen2_audio_tokenizer,
|
||||
):
|
||||
"""Test audio_embeds with async futures."""
|
||||
import base64
|
||||
import io
|
||||
|
||||
import torch
|
||||
|
||||
# Create a sample audio embedding tensor
|
||||
audio_embedding = torch.randn(1, 128, 768)
|
||||
|
||||
# Encode it as base64
|
||||
buffer = io.BytesIO()
|
||||
torch.save(audio_embedding, buffer)
|
||||
buffer.seek(0)
|
||||
binary_data = buffer.read()
|
||||
base64_audio_embedding = base64.b64encode(binary_data).decode("utf-8")
|
||||
|
||||
conversation, mm_future, mm_uuids = parse_chat_messages_futures(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Describe this audio"},
|
||||
{
|
||||
"type": "audio_embeds",
|
||||
"audio_embeds": base64_audio_embedding,
|
||||
},
|
||||
],
|
||||
}
|
||||
],
|
||||
audio_embeds_model_config,
|
||||
qwen2_audio_tokenizer,
|
||||
content_format="string",
|
||||
)
|
||||
|
||||
# Should have audio embedding in mm_data (single tensor, not a list)
|
||||
mm_data = await mm_future
|
||||
assert mm_data is not None
|
||||
assert "audio" in mm_data
|
||||
assert isinstance(mm_data["audio"], torch.Tensor)
|
||||
assert mm_data["audio"].shape == audio_embedding.shape
|
||||
# No UUID provided
|
||||
assert mm_uuids is not None
|
||||
assert "audio" in mm_uuids
|
||||
_assert_mm_uuids(mm_uuids, 1, modality="audio", expected_uuids=[None])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_parse_chat_messages_empty_image_embeds_with_uuid_async(
|
||||
phi3v_model_config_image_embeds,
|
||||
|
||||
Reference in New Issue
Block a user