[Mistral-Small 3.1] Update docs and tests (#14977)

Signed-off-by: Roger Wang <ywang@roblox.com>
Co-authored-by: Roger Wang <ywang@roblox.com>
Patrick von Platen authored 2025-03-18 11:29:42 +01:00, committed by GitHub
parent 400d483e87
commit f863ffc965
5 changed files with 34 additions and 60 deletions


@@ -4,7 +4,6 @@
 Run `pytest tests/models/test_mistral.py`.
 """
 import json
-import uuid
 from dataclasses import asdict
 from typing import TYPE_CHECKING, Any, Optional
@@ -16,8 +15,7 @@ from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
 from mistral_common.tokens.tokenizers.multimodal import image_from_chunk
 from transformers import AutoProcessor

-from vllm import (EngineArgs, LLMEngine, RequestOutput, SamplingParams,
-                  TextPrompt, TokensPrompt)
+from vllm import RequestOutput, SamplingParams, TextPrompt, TokensPrompt
 from vllm.multimodal import MultiModalDataBuiltins
 from vllm.multimodal.inputs import PlaceholderRange
 from vllm.sequence import Logprob, SampleLogprobs
@@ -28,7 +26,11 @@ from ...utils import check_logprobs_close
 if TYPE_CHECKING:
     from _typeshed import StrPath

-MODELS = ["mistralai/Pixtral-12B-2409"]
+PIXTRAL_ID = "mistralai/Pixtral-12B-2409"
+MISTRAL_SMALL_3_1_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
+
+MODELS = [PIXTRAL_ID, MISTRAL_SMALL_3_1_ID]
+
 IMG_URLS = [
     "https://picsum.photos/id/237/400/300",
     "https://picsum.photos/id/231/200/300",
@@ -125,8 +127,10 @@ MAX_MODEL_LEN = [8192, 65536]
 FIXTURES_PATH = VLLM_PATH / "tests/models/fixtures"
 assert FIXTURES_PATH.exists()

-FIXTURE_LOGPROBS_CHAT = FIXTURES_PATH / "pixtral_chat.json"
-FIXTURE_LOGPROBS_ENGINE = FIXTURES_PATH / "pixtral_chat_engine.json"
+FIXTURE_LOGPROBS_CHAT = {
+    PIXTRAL_ID: FIXTURES_PATH / "pixtral_chat.json",
+    MISTRAL_SMALL_3_1_ID: FIXTURES_PATH / "mistral_small_3_chat.json",
+}

 OutputsLogprobs = list[tuple[list[int], str, Optional[SampleLogprobs]]]
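Aside (not part of the diff): a minimal, self-contained sketch of the lookup pattern this hunk introduces, with the model IDs and fixture file names taken from the diff above. The `fixture_for` helper is illustrative only; in the test itself the mapping is indexed directly with the parameterized `model`, so every entry in `MODELS` needs a matching reference-logprobs file.

```python
from pathlib import Path

PIXTRAL_ID = "mistralai/Pixtral-12B-2409"
MISTRAL_SMALL_3_1_ID = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"

FIXTURES_PATH = Path("tests/models/fixtures")
FIXTURE_LOGPROBS_CHAT = {
    PIXTRAL_ID: FIXTURES_PATH / "pixtral_chat.json",
    MISTRAL_SMALL_3_1_ID: FIXTURES_PATH / "mistral_small_3_chat.json",
}


def fixture_for(model: str) -> Path:
    # A KeyError here would mean a model was added to MODELS without
    # recording a reference fixture for it.
    return FIXTURE_LOGPROBS_CHAT[model]
```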
@@ -166,12 +170,12 @@ def test_chat(
     model: str,
     dtype: str,
 ) -> None:
-    EXPECTED_CHAT_LOGPROBS = load_outputs_w_logprobs(FIXTURE_LOGPROBS_CHAT)
+    EXPECTED_CHAT_LOGPROBS = load_outputs_w_logprobs(
+        FIXTURE_LOGPROBS_CHAT[model])
     with vllm_runner(
             model,
             dtype=dtype,
             tokenizer_mode="mistral",
-            enable_chunked_prefill=False,
             max_model_len=max_model_len,
             limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
     ) as vllm_model:
@@ -183,70 +187,40 @@ def test_chat(
         outputs.extend(output)

     logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)
+
+    # Remove last `None` prompt_logprobs to compare with fixture
+    for i in range(len(logprobs)):
+        assert logprobs[i][-1] is None
+        logprobs[i] = logprobs[i][:-1]
+
     check_logprobs_close(outputs_0_lst=EXPECTED_CHAT_LOGPROBS,
                          outputs_1_lst=logprobs,
                          name_0="h100_ref",
                          name_1="output")


-@large_gpu_test(min_gb=80)
-@pytest.mark.parametrize("model", MODELS)
-@pytest.mark.parametrize("dtype", ["bfloat16"])
-def test_model_engine(vllm_runner, model: str, dtype: str) -> None:
-    EXPECTED_ENGINE_LOGPROBS = load_outputs_w_logprobs(FIXTURE_LOGPROBS_ENGINE)
-    args = EngineArgs(
-        model=model,
-        tokenizer_mode="mistral",
-        enable_chunked_prefill=False,
-        limit_mm_per_prompt=LIMIT_MM_PER_PROMPT,
-        dtype=dtype,
-    )
-    engine = LLMEngine.from_engine_args(args)
-
-    engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[0], SAMPLING_PARAMS)
-    engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[1], SAMPLING_PARAMS)
-
-    outputs = []
-    count = 0
-    while True:
-        out = engine.step()
-        count += 1
-        for request_output in out:
-            if request_output.finished:
-                outputs.append(request_output)
-
-        if count == 2:
-            engine.add_request(uuid.uuid4().hex, ENGINE_INPUTS[2],
-                               SAMPLING_PARAMS)
-
-        if not engine.has_unfinished_requests():
-            break
-
-    logprobs = vllm_runner._final_steps_generate_w_logprobs(outputs)
-    check_logprobs_close(outputs_0_lst=EXPECTED_ENGINE_LOGPROBS,
-                         outputs_1_lst=logprobs,
-                         name_0="h100_ref",
-                         name_1="output")
-
-
 @large_gpu_test(min_gb=48)
 @pytest.mark.parametrize(
     "prompt,expected_ranges",
     [(_create_engine_inputs_hf(IMG_URLS[:1]), [{
-        "offset": 10,
+        "offset": 11,
         "length": 494
     }]),
      (_create_engine_inputs_hf(IMG_URLS[1:4]), [{
-        "offset": 10,
+        "offset": 11,
         "length": 266
     }, {
-        "offset": 276,
+        "offset": 277,
         "length": 1056
     }, {
-        "offset": 1332,
+        "offset": 1333,
         "length": 418
     }])])
-def test_multi_modal_placeholders(
-        vllm_runner, prompt, expected_ranges: list[PlaceholderRange]) -> None:
+def test_multi_modal_placeholders(vllm_runner, prompt,
+                                  expected_ranges: list[PlaceholderRange],
+                                  monkeypatch) -> None:
+    # This placeholder checking test only works with V0 engine
+    # where `multi_modal_placeholders` is returned with `RequestOutput`
+    monkeypatch.setenv("VLLM_USE_V1", "0")
     with vllm_runner(
             "mistral-community/pixtral-12b",
             max_model_len=8192,
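For context, a hedged sketch of the check this test performs once the V0 engine is pinned. It assumes a `RequestOutput`-like object whose `multi_modal_placeholders` maps the `"image"` modality to offset/length ranges (per the comment added in the diff; the exact shape of those ranges may differ across vLLM versions), and the helper name is hypothetical:

```python
from typing import Any


def assert_image_placeholders(request_output: Any,
                              expected_ranges: list[dict]) -> None:
    # Assumption: under V0 (VLLM_USE_V1=0) the output carries placeholder
    # ranges per modality; V1 does not return them, hence the monkeypatch.
    placeholders = request_output.multi_modal_placeholders["image"]
    assert len(placeholders) == len(expected_ranges)
    for placeholder, expected in zip(placeholders, expected_ranges):
        # The diff shifts every expected offset by one token
        # (10 -> 11, 276 -> 277, 1332 -> 1333), tracking the updated
        # prompt tokenization; the placeholder lengths are unchanged.
        assert placeholder["offset"] == expected["offset"]
        assert placeholder["length"] == expected["length"]
```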