[Core][Bugfix] Fix Online MM Beam Search (#19688)
Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
This commit is contained in:
@@ -25,6 +25,25 @@ TEST_IMAGE_URLS = [
|
||||
"https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png",
|
||||
]
|
||||
|
||||
EXPECTED_MM_BEAM_SEARCH_RES = [
|
||||
[
|
||||
"The image shows a wooden boardwalk leading through a",
|
||||
"The image shows a wooden boardwalk extending into a",
|
||||
],
|
||||
[
|
||||
"The image shows two parrots perched on",
|
||||
"The image shows two birds perched on a cur",
|
||||
],
|
||||
[
|
||||
"The image shows a Venn diagram with three over",
|
||||
"This image shows a Venn diagram with three over",
|
||||
],
|
||||
[
|
||||
"This image displays a gradient of colors ranging from",
|
||||
"This image displays a gradient of colors transitioning from",
|
||||
],
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
@@ -270,10 +289,13 @@ async def test_single_chat_session_image_base64encoded(
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
||||
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
|
||||
@pytest.mark.parametrize("image_idx", list(range(len(TEST_IMAGE_URLS))))
|
||||
async def test_single_chat_session_image_base64encoded_beamsearch(
|
||||
client: openai.AsyncOpenAI, model_name: str, image_url: str,
|
||||
client: openai.AsyncOpenAI, model_name: str, image_idx: int,
|
||||
base64_encoded_image: dict[str, str]):
|
||||
# NOTE: This test also validates that we pass MM data through beam search
|
||||
image_url = TEST_IMAGE_URLS[image_idx]
|
||||
expected_res = EXPECTED_MM_BEAM_SEARCH_RES[image_idx]
|
||||
|
||||
messages = [{
|
||||
"role":
|
||||
@@ -297,10 +319,11 @@ async def test_single_chat_session_image_base64encoded_beamsearch(
|
||||
messages=messages,
|
||||
n=2,
|
||||
max_completion_tokens=10,
|
||||
temperature=0.0,
|
||||
extra_body=dict(use_beam_search=True))
|
||||
assert len(chat_completion.choices) == 2
|
||||
assert chat_completion.choices[
|
||||
0].message.content != chat_completion.choices[1].message.content
|
||||
for actual, expected_str in zip(chat_completion.choices, expected_res):
|
||||
assert actual.message.content == expected_str
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
Reference in New Issue
Block a user