[Hotfix][Pixtral] Fix multiple images bugs (#8415)
commit d31174a4e1 (parent b61bd98f90)
BIN  tests/models/fixtures/pixtral_chat.pickle (new file; binary file not shown)
BIN  tests/models/fixtures/pixtral_chat_engine.pickle (new file; binary file not shown)
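Both fixtures are new binary files holding reference logprobs captured on an H100 (the tests below compare against them under the name `h100_ref`). A minimal sketch of peeking at one, assuming only that it unpickles to the per-prompt structure that `check_logprobs_close` consumes — the exact nesting is an inference from the test code, not documented in the diff:

```python
import pickle

# Hedged sketch, not part of the diff: load the pickled reference
# logprobs the same way load_logprobs does in the test file below.
with open("tests/models/fixtures/pixtral_chat.pickle", "rb") as f:
    reference = pickle.load(f)

# Assumption: one entry per prompt in MSGS, mirroring the live outputs
# that the tests collect and pass to check_logprobs_close.
print(type(reference), len(reference))
```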
tests/models/test_pixtral.py
@@ -2,13 +2,128 @@
 
 Run `pytest tests/models/test_mistral.py`.
 """
-import pytest
+import pickle
+import uuid
+from typing import Any, Dict, List
 
-from vllm.sampling_params import SamplingParams
+import pytest
+from mistral_common.protocol.instruct.messages import ImageURLChunk
+from mistral_common.protocol.instruct.request import ChatCompletionRequest
+from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
+from mistral_common.tokens.tokenizers.multimodal import image_from_chunk
+
+from vllm import EngineArgs, LLMEngine, SamplingParams, TokensPrompt
+from vllm.multimodal import MultiModalDataBuiltins
+
+from .utils import check_logprobs_close
 
 pytestmark = pytest.mark.vlm
 
 MODELS = ["mistralai/Pixtral-12B-2409"]
+IMG_URLS = [
+    "https://picsum.photos/id/237/400/300",
+    "https://picsum.photos/id/231/200/300",
+    "https://picsum.photos/id/27/500/500",
+    "https://picsum.photos/id/17/150/600",
+]
+PROMPT = "Describe each image in one short sentence."
+
+
+def _create_msg_format(urls: List[str]) -> List[Dict[str, Any]]:
+    return [{
+        "role":
+        "user",
+        "content": [{
+            "type": "text",
+            "text": PROMPT,
+        }] + [{
+            "type": "image_url",
+            "image_url": {
+                "url": url
+            }
+        } for url in urls],
+    }]
+
+
+def _create_engine_inputs(urls: List[str]) -> TokensPrompt:
+    msg = _create_msg_format(urls)
+
+    tokenizer = MistralTokenizer.from_model("pixtral")
+
+    request = ChatCompletionRequest(messages=msg)  # type: ignore[type-var]
+    tokenized = tokenizer.encode_chat_completion(request)
+
+    engine_inputs = TokensPrompt(prompt_token_ids=tokenized.tokens)
+
+    images = []
+    for chunk in request.messages[0].content:
+        if isinstance(chunk, ImageURLChunk):
+            images.append(image_from_chunk(chunk))
+
+    mm_data = MultiModalDataBuiltins(image=images)
+    engine_inputs["multi_modal_data"] = mm_data
+
+    return engine_inputs
+
+
+MSGS = [
+    _create_msg_format(IMG_URLS[:1]),
+    _create_msg_format(IMG_URLS[:2]),
+    _create_msg_format(IMG_URLS),
+]
+ENGINE_INPUTS = [
+    _create_engine_inputs(IMG_URLS[:1]),
+    _create_engine_inputs(IMG_URLS[:2]),
+    _create_engine_inputs(IMG_URLS),
+]
+
+SAMPLING_PARAMS = SamplingParams(max_tokens=512, temperature=0.0, logprobs=5)
+LIMIT_MM_PER_PROMPT = dict(image=4)
+
+MAX_MODEL_LEN = [8192, 65536]
+FIXTURE_LOGPROBS_CHAT = "tests/models/fixtures/pixtral_chat.pickle"
+FIXTURE_LOGPROBS_ENGINE = "tests/models/fixtures/pixtral_chat_engine.pickle"
+
+
+def load_logprobs(filename: str) -> Any:
+    with open(filename, 'rb') as f:
+        return pickle.load(f)
+
+
+@pytest.mark.skip(
+    reason=
+    "Model is too big, test passed on A100 locally but will OOM on CI machine."
+)
+@pytest.mark.parametrize("model", MODELS)
+@pytest.mark.parametrize("max_model_len", MAX_MODEL_LEN)
+@pytest.mark.parametrize("dtype", ["bfloat16"])
+def test_chat(
+    vllm_runner,
+    max_model_len: int,
+    model: str,
+    dtype: str,
+) -> None:
+    EXPECTED_CHAT_LOGPROBS = load_logprobs(FIXTURE_LOGPROBS_CHAT)
+    with vllm_runner(
+            model,
+            dtype=dtype,
+            tokenizer_mode="mistral",
+            enable_chunked_prefill=False,
+            max_model_len=max_model_len,
+            limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
+    ) as vllm_model:
+        outputs = []
+        for msg in MSGS:
+            output = vllm_model.model.chat(msg,
+                                           sampling_params=SAMPLING_PARAMS)
+
+            outputs.extend(output)
+
+    logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)
+    check_logprobs_close(outputs_0_lst=logprobs,
+                         outputs_1_lst=EXPECTED_CHAT_LOGPROBS,
+                         name_0="output",
+                         name_1="h100_ref")
 
 
 @pytest.mark.skip(
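The helpers added above establish the data flow at the heart of this fix: `_create_msg_format` builds an OpenAI-style multimodal message (one text chunk plus one `image_url` chunk per image), and `_create_engine_inputs` tokenizes it with `mistral_common` so the engine receives pre-computed token IDs while the decoded PIL images travel separately through `multi_modal_data`. A condensed restatement for the two-image case, using only APIs that appear in the diff:

```python
from mistral_common.protocol.instruct.messages import ImageURLChunk
from mistral_common.protocol.instruct.request import ChatCompletionRequest
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.tokens.tokenizers.multimodal import image_from_chunk

from vllm import TokensPrompt
from vllm.multimodal import MultiModalDataBuiltins

urls = ["https://picsum.photos/id/237/400/300",
        "https://picsum.photos/id/231/200/300"]
msg = [{
    "role": "user",
    "content": [{"type": "text",
                 "text": "Describe each image in one short sentence."}] +
               [{"type": "image_url", "image_url": {"url": u}} for u in urls],
}]

# Tokenize the chat request with the Pixtral tokenizer.
tokenizer = MistralTokenizer.from_model("pixtral")
request = ChatCompletionRequest(messages=msg)  # type: ignore[type-var]
tokenized = tokenizer.encode_chat_completion(request)

# Token IDs go in as the prompt; the images ride along as multimodal data.
prompt = TokensPrompt(prompt_token_ids=tokenized.tokens)
images = [image_from_chunk(c) for c in request.messages[0].content
          if isinstance(c, ImageURLChunk)]
prompt["multi_modal_data"] = MultiModalDataBuiltins(image=images)
```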
@@ -17,48 +132,37 @@ MODELS = ["mistralai/Pixtral-12B-2409"]
 )
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["bfloat16"])
-@pytest.mark.parametrize("max_tokens", [64])
-@pytest.mark.parametrize("num_logprobs", [5])
-def test_models(
-    vllm_runner,
-    example_prompts,
-    model: str,
-    dtype: str,
-    max_tokens: int,
-    num_logprobs: int,
-) -> None:
-    image_urls = [
-        "https://picsum.photos/id/237/200/300",
-        "https://picsum.photos/seed/picsum/200/300"
-    ]
-    expected = [
-        "The image depicts a black dog lying on a wooden surface, looking directly at the camera with a calm expression.",  # noqa
-        "The image depicts a serene landscape with a snow-covered mountain under a pastel-colored sky during sunset."  # noqa
-    ]
-    prompt = "Describe the image in one short sentence."
-
-    sampling_params = SamplingParams(max_tokens=512, temperature=0.0)
-
-    with vllm_runner(model, dtype=dtype,
-                     tokenizer_mode="mistral") as vllm_model:
-
-        for i, image_url in enumerate(image_urls):
-            messages = [
-                {
-                    "role":
-                    "user",
-                    "content": [{
-                        "type": "text",
-                        "text": prompt
-                    }, {
-                        "type": "image_url",
-                        "image_url": {
-                            "url": image_url
-                        }
-                    }]
-                },
-            ]
-
-            outputs = vllm_model.model.chat(messages,
-                                            sampling_params=sampling_params)
-            assert outputs[0].outputs[0].text == expected[i]
+def test_model_engine(vllm_runner, model: str, dtype: str) -> None:
+    EXPECTED_ENGINE_LOGPROBS = load_logprobs(FIXTURE_LOGPROBS_ENGINE)
+    args = EngineArgs(
+        model=model,
+        tokenizer_mode="mistral",
+        enable_chunked_prefill=False,
+        limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
+        dtype=dtype,
+    )
+    engine = LLMEngine.from_engine_args(args)
+
+    engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[0], SAMPLING_PARAMS)
+    engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[1], SAMPLING_PARAMS)
+
+    outputs = []
+    count = 0
+    while True:
+        out = engine.step()
+        count += 1
+        for request_output in out:
+            if request_output.finished:
+                outputs.append(request_output)
+
+        if count == 2:
+            engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[2],
+                               SAMPLING_PARAMS)
+        if not engine.has_unfinished_requests():
+            break
+
+    logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)
+    check_logprobs_close(outputs_0_lst=logprobs,
+                         outputs_1_lst=EXPECTED_ENGINE_LOGPROBS,
+                         name_0="output",
+                         name_1="h100_ref")
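The new `test_model_engine` drives `LLMEngine` by hand: two requests with different image counts are queued up front, and the third (four-image) request is injected after two `step()` calls so prompts with mixed image counts end up batched together — exactly the multi-image path this hotfix targets. A generic sketch of that pattern; `run_to_completion` is a hypothetical helper written for illustration, not anything in vLLM:

```python
import uuid
from typing import Any, Dict, List, Optional

from vllm import LLMEngine, SamplingParams


def run_to_completion(engine: LLMEngine,
                      inputs: List[Any],
                      params: SamplingParams,
                      inject_at_step: Optional[Dict[int, List[Any]]] = None
                      ) -> List[Any]:
    # Hypothetical helper (not part of the diff): queue the initial
    # inputs, then step the engine until every request has finished,
    # optionally adding more requests mid-flight at given step counts.
    for item in inputs:
        engine.add_request(uuid.uuid4().hex, item, params)

    finished: List[Any] = []
    steps = 0
    while True:
        for out in engine.step():
            if out.finished:
                finished.append(out)
        steps += 1
        for item in (inject_at_step or {}).get(steps, []):
            engine.add_request(uuid.uuid4().hex, item, params)
        if not engine.has_unfinished_requests():
            break
    return finished
```

With the constants from the test file, `run_to_completion(engine, ENGINE_INPUTS[:2], SAMPLING_PARAMS, {2: [ENGINE_INPUTS[2]]})` would reproduce the test's scheduling behaviour.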