[Core][Bugfix] Fix Online MM Beam Search (#19688)

Signed-off-by: Alex-Brooks <Alex.Brooks@ibm.com>
This commit is contained in:
Alex Brooks
2025-06-19 11:18:07 -06:00
committed by GitHub
parent 01220ce89a
commit ead2110297
3 changed files with 45 additions and 12 deletions

View File

@@ -25,6 +25,25 @@ TEST_IMAGE_URLS = [
"https://upload.wikimedia.org/wikipedia/commons/0/0b/RGBA_comp.png",
]
EXPECTED_MM_BEAM_SEARCH_RES = [
[
"The image shows a wooden boardwalk leading through a",
"The image shows a wooden boardwalk extending into a",
],
[
"The image shows two parrots perched on",
"The image shows two birds perched on a cur",
],
[
"The image shows a Venn diagram with three over",
"This image shows a Venn diagram with three over",
],
[
"This image displays a gradient of colors ranging from",
"This image displays a gradient of colors transitioning from",
],
]
@pytest.fixture(scope="module")
def server():
@@ -270,10 +289,13 @@ async def test_single_chat_session_image_base64encoded(
@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS)
@pytest.mark.parametrize("image_idx", list(range(len(TEST_IMAGE_URLS))))
async def test_single_chat_session_image_base64encoded_beamsearch(
client: openai.AsyncOpenAI, model_name: str, image_url: str,
client: openai.AsyncOpenAI, model_name: str, image_idx: int,
base64_encoded_image: dict[str, str]):
# NOTE: This test also validates that we pass MM data through beam search
image_url = TEST_IMAGE_URLS[image_idx]
expected_res = EXPECTED_MM_BEAM_SEARCH_RES[image_idx]
messages = [{
"role":
@@ -297,10 +319,11 @@ async def test_single_chat_session_image_base64encoded_beamsearch(
messages=messages,
n=2,
max_completion_tokens=10,
temperature=0.0,
extra_body=dict(use_beam_search=True))
assert len(chat_completion.choices) == 2
assert chat_completion.choices[
0].message.content != chat_completion.choices[1].message.content
for actual, expected_str in zip(chat_completion.choices, expected_res):
assert actual.message.content == expected_str
@pytest.mark.asyncio